diff --git a/convlab/base_models/t5/__init__.py b/convlab/base_models/t5/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/convlab/base_models/t5/run_seq2seq.py b/convlab/base_models/t5/run_seq2seq.py
index 8ce0f8d7f305b13b10ff0f5b899094fc2a4c96df..7aac3c70746e877469fc34892cd3f93f9fd01f22 100644
--- a/convlab/base_models/t5/run_seq2seq.py
+++ b/convlab/base_models/t5/run_seq2seq.py
@@ -39,14 +39,13 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
     EarlyStoppingCallback,
     set_seed,
 )
 from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from convlab.base_models.t5.trainer import ConvLabSeq2SeqTrainer, ConvLabSeq2SeqTrainingArguments

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -249,7 +248,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.

-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ConvLabSeq2SeqTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -556,7 +555,7 @@ def main():
         training_args.generation_max_length = data_args.val_max_target_length

     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
+    trainer = ConvLabSeq2SeqTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
diff --git a/convlab/base_models/t5/trainer.py b/convlab/base_models/t5/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f125575e0ebfb44e05ef16e4d8d041e016cc81
--- /dev/null
+++ b/convlab/base_models/t5/trainer.py
@@ -0,0 +1,132 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from email.policy import default
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass, field
+import torch
+from torch import nn
+from torch.utils.data import Dataset
+
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.trainer_utils import PredictionOutput
+from transformers.utils import logging, add_start_docstrings
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+
+
+logger = logging.get_logger(__name__)
+
+@dataclass
+class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
+    """
+    `ConvLabSeq2SeqTrainingArguments` is a subclass of `Seq2SeqTrainingArguments` that adds the
+    following generation arguments: `do_sample`, `temperature`, `top_k`, `top_p`, and
+    `num_return_sequences`.
+    """
+    do_sample: bool = field(default=False, metadata={"help": "Whether or not to use sampling; use greedy decoding otherwise."})
+    temperature: Optional[float] = field(default=1.0, metadata={"help": "The value used to modulate the next token probabilities."})
+    top_k: Optional[int] = field(default=0, metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k-filtering."})
+    top_p: Optional[float] = field(default=1.0, metadata={"help": "If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation."})
+    num_return_sequences: Optional[int] = field(default=1, metadata={"help": "The number of independently computed returned sequences for each element in the batch."})
+
+
+class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on `model` using `inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+
+        Return:
+            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
+            labels (each being optional).
+        """
+
+        if not self.args.predict_with_generate or prediction_loss_only:
+            return super().prediction_step(
+                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            )
+
+        has_labels = "labels" in inputs
+        inputs = self._prepare_inputs(inputs)
+
+        # XXX: adapt synced_gpus for fairscale as well
+        gen_kwargs = {
+            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
+            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
+            "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
+            "do_sample": self.args.do_sample,
+            "temperature": self.args.temperature,
+            "top_k": self.args.top_k,
+            "top_p": self.args.top_p,
+            "num_return_sequences": self.args.num_return_sequences,
+        }
+
+        if "attention_mask" in inputs:
+            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+        if "global_attention_mask" in inputs:
+            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
+
+        # prepare generation inputs
+        # some encoder-decoder models can have varying encoders and thus
+        # varying model input names
+        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
+            generation_inputs = inputs[self.model.encoder.main_input_name]
+        else:
+            generation_inputs = inputs[self.model.main_input_name]
+
+        generated_tokens = self.model.generate(
+            generation_inputs,
+            **gen_kwargs,
+        )
+        # in case the batch is shorter than max length, the output should be padded
+        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+
+        with torch.no_grad():
+            with self.autocast_smart_context_manager():
+                outputs = model(**inputs)
+            if has_labels:
+                if self.label_smoother is not None:
+                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                else:
+                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+            else:
+                loss = None
+
+        if self.args.prediction_loss_only:
+            return (loss, None, None)
+
+        if has_labels:
+            labels = inputs["labels"]
+            if labels.shape[-1] < gen_kwargs["max_length"]:
+                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+        else:
+            labels = None
+
+        return (loss, generated_tokens, labels)
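
Note on usage (not part of the diff): because ConvLabSeq2SeqTrainingArguments is a dataclass handled by HfArgumentParser in run_seq2seq.py, the new sampling fields can be supplied like any other training argument and are then forwarded to model.generate() inside prediction_step(). The snippet below is a minimal sketch of constructing the arguments programmatically; the output_dir path and the hyperparameter values are illustrative, not values taken from this PR.

    # Minimal sketch: build the extended training arguments and inspect the
    # generation-related fields that ConvLabSeq2SeqTrainer reads into gen_kwargs.
    from convlab.base_models.t5.trainer import ConvLabSeq2SeqTrainingArguments

    args = ConvLabSeq2SeqTrainingArguments(
        output_dir="output/t5_example",   # illustrative path
        predict_with_generate=True,       # generation path in prediction_step requires this
        do_sample=True,                   # sample instead of greedy/beam decoding
        temperature=0.7,                  # illustrative sampling hyperparameters
        top_k=50,
        top_p=0.9,
        num_return_sequences=1,
    )

    # These attributes are exactly what prediction_step() copies into gen_kwargs.
    print(args.do_sample, args.temperature, args.top_k, args.top_p, args.num_return_sequences)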