diff --git a/convlab/base_models/t5/__init__.py b/convlab/base_models/t5/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/convlab/base_models/t5/run_seq2seq.py b/convlab/base_models/t5/run_seq2seq.py
index 8ce0f8d7f305b13b10ff0f5b899094fc2a4c96df..7aac3c70746e877469fc34892cd3f93f9fd01f22 100644
--- a/convlab/base_models/t5/run_seq2seq.py
+++ b/convlab/base_models/t5/run_seq2seq.py
@@ -39,14 +39,13 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
     EarlyStoppingCallback,
     set_seed,
 )
 from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
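+# ConvLab's trainer subclass adds sampling-related decoding arguments
+# (do_sample, temperature, top_k, top_p, num_return_sequences) that are
+# forwarded to `generate` during evaluation and prediction.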
+from convlab.base_models.t5.trainer import ConvLabSeq2SeqTrainer, ConvLabSeq2SeqTrainingArguments
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -249,7 +248,7 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ConvLabSeq2SeqTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
@@ -556,7 +555,7 @@ def main():
         training_args.generation_max_length = data_args.val_max_target_length
 
     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
+    trainer = ConvLabSeq2SeqTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
diff --git a/convlab/base_models/t5/trainer.py b/convlab/base_models/t5/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f125575e0ebfb44e05ef16e4d8d041e016cc81
--- /dev/null
+++ b/convlab/base_models/t5/trainer.py
@@ -0,0 +1,132 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass, field
+import torch
+from torch import nn
+
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.utils import logging
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
+    """
+    `ConvLabSeq2SeqTrainingArguments` is a subclass of `Seq2SeqTrainingArguments` that adds the
+    following generation arguments: `do_sample`, `temperature`, `top_k`, `top_p`, and
+    `num_return_sequences`.
+    """
+    do_sample: bool = field(default=False, metadata={"help": "Whether or not to use sampling; use greedy decoding otherwise."})
+    temperature: Optional[float] = field(default=1.0, metadata={"help": "The value used to module the next token probabilities."})
+    top_k: Optional[int] = field(default=0, metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k-filtering."})
+    top_p: Optional[float] = field(default=1.0, metadata={"help": "If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation."})
+    num_return_sequences: Optional[int] = field(default=1, metadata={"help": "The number of independently computed returned sequences for each element in the batch."})
+
+
+class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on `model` using `inputs`.
+        Subclass and override to inject custom behavior.
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+        Return:
+            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
+            labels (each being optional).
+        """
+
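+        # Without generation there is nothing to customize here; defer to the
+        # parent implementation (which also handles the loss-only case).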
+        if not self.args.predict_with_generate or prediction_loss_only:
+            return super().prediction_step(
+                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            )
+
+        has_labels = "labels" in inputs
+        inputs = self._prepare_inputs(inputs)
+
+        # XXX: adapt synced_gpus for fairscale as well
+        gen_kwargs = {
+            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
+            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
+            "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
+            "do_sample": self.args.do_sample,
+            "temperature": self.args.temperature,
+            "top_k": self.args.top_k,
+            "top_p": self.args.top_p,
+            "num_return_sequences": self.args.num_return_sequences
+        }
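+        # Note: when `num_return_sequences` > 1, `generate` returns
+        # `batch_size * num_return_sequences` sequences, so any downstream
+        # metric computation must regroup predictions against the original batch.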
+
+        if "attention_mask" in inputs:
+            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
+        if "global_attention_mask" in inputs:
+            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
+
+        # prepare generation inputs
+        # some encoder-decoder models can have varying encoders and thus
+        # varying model input names
+        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
+            generation_inputs = inputs[self.model.encoder.main_input_name]
+        else:
+            generation_inputs = inputs[self.model.main_input_name]
+
+        generated_tokens = self.model.generate(
+            generation_inputs,
+            **gen_kwargs,
+        )
+        # in case the batch is shorter than max length, the output should be padded
+        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+
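+        # The loss is computed with a regular teacher-forced forward pass; the
+        # sequences generated above are used only as predictions.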
+        with torch.no_grad():
+            with self.autocast_smart_context_manager():
+                outputs = model(**inputs)
+            if has_labels:
+                if self.label_smoother is not None:
+                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                else:
+                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+            else:
+                loss = None
+
+        if self.args.prediction_loss_only:
+            return (loss, None, None)
+
+        if has_labels:
+            labels = inputs["labels"]
+            if labels.shape[-1] < gen_kwargs["max_length"]:
+                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+        else:
+            labels = None
+
+        return (loss, generated_tokens, labels)
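+
+
+if __name__ == "__main__":
+    # Minimal usage sketch with illustrative values (not part of ConvLab's
+    # public API): it shows that the new sampling options are ordinary
+    # training-argument fields that `prediction_step` forwards to `generate`.
+    demo_args = ConvLabSeq2SeqTrainingArguments(
+        output_dir="demo_output",  # hypothetical scratch directory
+        predict_with_generate=True,
+        do_sample=True,
+        temperature=0.7,
+        top_k=50,
+        top_p=0.9,
+    )
+    print({name: getattr(demo_args, name)
+           for name in ("do_sample", "temperature", "top_k", "top_p", "num_return_sequences")})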