41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
|
|
|
|
class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs) -> None:
        # The underlying torch.autocast is only constructed when enabled; when
        # disabled, ``self.autocast`` is None and enter/exit become no-ops.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self) -> None:
        """Enter the wrapped autocast context (no-op when disabled).

        Raises:
            RuntimeError: If entering torch.autocast fails. The new message
                reports the configured dtype and device, and the original
                error is chained as ``__cause__``.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Re-raise with a more actionable message, explicitly chaining the
            # original error so its details remain visible in the traceback.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped autocast context (no-op when disabled).

        Returns:
            The return value of the wrapped ``torch.autocast.__exit__``, which is
            forwarded so the context-manager suppression contract is preserved;
            None when autocast is disabled.
        """
        if self.autocast is None:
            return None
        return self.autocast.__exit__(*args, **kwargs)
|