# modeling_utils.py
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING

import torch.cuda.amp as amp
import transformers
from transformers import GPT2LMHeadModel


# reference: https://pytorch.org/docs/master/notes/amp_examples.html
class AmpGPT2LMHeadModel(GPT2LMHeadModel):
    if TYPE_CHECKING:
        # Expose the original signature so IDEs and type checkers see the real hints.
        forward = GPT2LMHeadModel.forward
    else:
        def forward(self, *args, **kwargs):
            # Run the forward pass under autocast so eligible ops use mixed precision.
            with amp.autocast():
                return super().forward(*args, **kwargs)
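
# Minimal usage sketch (illustrative; assumes the public "gpt2" checkpoint and a
# prepared `input_ids` tensor, neither of which is defined in this module):
#
#   model = AmpGPT2LMHeadModel.from_pretrained("gpt2").cuda()
#   outputs = model(input_ids, labels=input_ids)  # forward runs under amp.autocast()
#   loss = outputs.loss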


def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
    if model.supports_gradient_checkpointing:
        model.gradient_checkpointing_enable()
    else:
        warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
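
# Illustrative call (assumes `model` is any Hugging Face PreTrainedModel instance):
#
#   try_enable_gradient_checkpointing(model)  # only warns if the model lacks support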


class AmpHelper:
    """Bundle the autocast context and GradScaler used for mixed-precision training.

    References:
        https://pytorch.org/docs/master/notes/amp_examples.html
    """

    def __init__(self, use_amp=True):
        self.use_amp = use_amp
        self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
        self.scaler = amp.GradScaler()

    def backward(self, loss):
        if self.use_amp:
            # Scale the loss so small gradients are not flushed to zero in fp16.
            return self.scaler.scale(loss).backward()
        else:
            return loss.backward()

    def step(self, optimizer):
        if self.use_amp:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()

    def might_unscale_(self, optimizer):
        if self.use_amp:
            # Unscales the gradients of the optimizer's assigned params in-place
            # (call before gradient clipping so the clip threshold sees true values).
            self.scaler.unscale_(optimizer)
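

# Sketch of one training step with AmpHelper, following the AMP examples referenced
# above (illustrative; `model`, `optimizer`, `batch`, and `max_grad_norm` are assumed
# to be defined by the caller):
#
#   helper = AmpHelper(use_amp=True)
#   with helper.might_enable_autocast:
#       loss = model(**batch).loss        # forward pass under autocast when AMP is on
#   helper.backward(loss)                 # scales the loss before backward if AMP is on
#   helper.might_unscale_(optimizer)      # unscale so clipping sees the true gradients
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#   helper.step(optimizer)                # scaler.step + scaler.update, or a plain step
#   optimizer.zero_grad()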