modeling_utils.py
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING

import torch.cuda.amp as amp
import transformers
from transformers import GPT2LMHeadModel
    
    
# reference: https://pytorch.org/docs/master/notes/amp_examples.html
class AmpGPT2LMHeadModel(GPT2LMHeadModel):
    if TYPE_CHECKING:
        # Preserve the original signature for IDE code hinting / static analysis.
        forward = GPT2LMHeadModel.forward
    else:
        def forward(self, *args, **kwargs):
            with amp.autocast():
                return super().forward(*args, **kwargs)
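
# Illustrative usage sketch (not part of the original module): the subclass is a
# drop-in replacement for GPT2LMHeadModel, so it can be loaded the usual way,
# e.g. `model = AmpGPT2LMHeadModel.from_pretrained("gpt2")`, after which every
# forward pass runs under torch.cuda.amp.autocast().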
    
    
def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
    if model.supports_gradient_checkpointing:
        model.gradient_checkpointing_enable()
    else:
        warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
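
# Illustrative note (not part of the original module): GPT-2 supports gradient
# checkpointing in transformers, so `try_enable_gradient_checkpointing(model)`
# enables it for the model above; for architectures without support, only the
# warning is emitted.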
    
    
class AmpHelper:
    """
    Helper that toggles automatic mixed precision (AMP) on or off with a single flag.

    References:
        https://pytorch.org/docs/master/notes/amp_examples.html
    """
    def __init__(self, use_amp=True):
        self.use_amp = use_amp
        self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
        self.scaler = amp.GradScaler()

    def backward(self, loss):
        if self.use_amp:
            return self.scaler.scale(loss).backward()
        else:
            return loss.backward()

    def step(self, optimizer):
        if self.use_amp:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()

    def might_unscale_(self, optimizer):
        if self.use_amp:
            # Unscales the gradients of optimizer's assigned params in-place
            self.scaler.unscale_(optimizer)
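

# Illustrative sketch (not part of the original module): one training step wiring
# AmpHelper into a standard loop. The gradient-clipping threshold and the batch
# format (a dict of tensors that includes "labels", so `outputs.loss` is set)
# are assumptions made for this example only.
def example_training_step(model, batch, optimizer, amp_helper, max_grad_norm=1.0):
    from torch.nn.utils import clip_grad_norm_  # local import; torch itself is not imported above

    optimizer.zero_grad()
    with amp_helper.might_enable_autocast:
        outputs = model(**batch)
        loss = outputs.loss
    amp_helper.backward(loss)             # scales the loss before backward() when AMP is enabled
    amp_helper.might_unscale_(optimizer)  # unscale so clipping sees the true gradient magnitudes
    clip_grad_norm_(model.parameters(), max_grad_norm)
    amp_helper.step(optimizer)            # scaler.step()/update() under AMP, plain step() otherwise
    return loss.detach()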