diff --git a/convlab/base_models/t5/trainer.py b/convlab/base_models/t5/trainer.py
index ba0bce934b9f173aecc797119d4f28c0c974251b..dd8c2dfccec649359b0c03fe6b47dcf57bb01b22 100644
--- a/convlab/base_models/t5/trainer.py
+++ b/convlab/base_models/t5/trainer.py
@@ -16,9 +16,11 @@
 # from dataclasses import dataclass, field
 # import torch
 # from torch import nn
+from torch.utils.data import Dataset
 
 # from transformers.deepspeed import is_deepspeed_zero3_enabled
 # from transformers.utils import logging, cached_property, torch_required
+from transformers.trainer_utils import PredictionOutput
 from transformers.training_args import (
     os, 
     torch,
@@ -161,6 +163,112 @@ class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
 
 
 class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
+    # modified from Seq2SeqTrainer of transformers 4.26.1: https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/trainer_seq2seq.py
+    # override `evaluate` and `predict` to accept generation kwargs, and add extra generation args in `prediction_step`
+    def evaluate(
+        self,
+        eval_dataset: Optional[Dataset] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+        **gen_kwargs
+    ) -> Dict[str, float]:
+        """
+        Run evaluation and return metrics.
+        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+        (pass it to the init `compute_metrics` argument).
+        You can also subclass and override this method to inject custom behavior.
+        Args:
+            eval_dataset (`Dataset`, *optional*):
+                Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
+                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+                method.
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example, the metrics "bleu" will be named
+                "eval_bleu" if the prefix is `"eval"` (default).
+            max_length (`int`, *optional*):
+                The maximum target length to use when predicting with the generate method.
+            num_beams (`int`, *optional*):
+                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
+                beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+        Returns:
+            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+            dictionary also contains the epoch number which comes from the training state.
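+        Example (illustrative)::
+            trainer.evaluate(max_new_tokens=64, num_beams=4)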
+        """
+
+        gen_kwargs = gen_kwargs.copy()
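+        # fill in defaults from the training arguments when the caller did not
+        # pass an explicit generation length or beam count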
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        gen_kwargs["num_beams"] = (
+            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+        )
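+        # stash the kwargs on the instance; `prediction_step` (called inside the
+        # parent evaluation loop) reads them back from `self._gen_kwargs`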
+        self._gen_kwargs = gen_kwargs
+
+        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
+    def predict(
+        self,
+        test_dataset: Dataset,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "test",
+        **gen_kwargs
+    ) -> PredictionOutput:
+        """
+        Run prediction and return predictions and potential metrics.
+        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+        will also return metrics, like in `evaluate()`.
+        Args:
+            test_dataset (`Dataset`):
+                Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
+                `model.forward()` method are automatically removed. Has to implement the method `__len__`.
+            ignore_keys (`List[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+                An optional prefix to be used as the metrics key prefix. For example, the metrics "bleu" will be named
+                "test_bleu" if the prefix is `"test"` (default).
+            max_length (`int`, *optional*):
+                The maximum target length to use when predicting with the generate method.
+            num_beams (`int`, *optional*):
+                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
+                beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+        <Tip>
+        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
+        padding in a token classification task) the predictions will be padded (on the right) to allow for
+        concatenation into one array. The padding index is -100.
+        </Tip>
+        Returns: *NamedTuple* A namedtuple with the following keys:
+            - predictions (`np.ndarray`): The predictions on `test_dataset`.
+            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+              labels).
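+        Example (illustrative)::
+            output = trainer.predict(test_dataset, max_new_tokens=64, num_beams=4)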
+        """
+
+        gen_kwargs = gen_kwargs.copy()
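+        # same default-filling and stashing as in `evaluate` above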
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        gen_kwargs["num_beams"] = (
+            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+        )
+        self._gen_kwargs = gen_kwargs
+
+        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
     def prediction_step(
         self,
         model: nn.Module,
@@ -194,16 +302,29 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
         inputs = self._prepare_inputs(inputs)
 
         # XXX: adapt synced_gpus for fairscale as well
-        gen_kwargs = {
-            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
-            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
-            "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
+        gen_kwargs = self._gen_kwargs.copy()
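+        # fall back to the model config only where evaluate()/predict() did not
+        # already inject values from the training arguments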
+        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+            gen_kwargs["max_length"] = self.model.config.max_length
+        gen_kwargs["num_beams"] = (
+            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
+        )
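+        # under DeepSpeed ZeRO-3, every rank must call generate() in lockstep,
+        # so synced_gpus defaults to True in that setting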
+        default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
+        gen_kwargs["synced_gpus"] = (
+            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
+        )
+
+        # ConvLab addition: forward the sampling-related generation arguments from the training args
+        gen_kwargs.update({
             "do_sample": self.args.do_sample,
             "temperature": self.args.temperature,
             "top_k": self.args.top_k,
             "top_p": self.args.top_p,
             "num_return_sequences": self.args.num_return_sequences
-        }
+        })
 
         if "attention_mask" in inputs:
             gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
@@ -223,13 +344,19 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
             **gen_kwargs,
         )
         # in case the batch is shorter than max length, the output should be padded
-        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
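+        # generate() also emits the decoder start token, so with max_new_tokens
+        # set, the output can be up to max_new_tokens + 1 tokens long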
+        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
+            gen_kwargs["max_new_tokens"] + 1
+        ):
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
 
         with torch.no_grad():
-            with self.autocast_smart_context_manager():
-                outputs = model(**inputs)
             if has_labels:
+                with self.compute_loss_context_manager():
+                    outputs = model(**inputs)
                 if self.label_smoother is not None:
                     loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                 else:
@@ -242,8 +369,14 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
 
         if has_labels:
             labels = inputs["labels"]
-            if labels.shape[-1] < gen_kwargs["max_length"]:
+            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                 labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
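+            # mirror the padding rule used for generated_tokens above so that
+            # predictions and labels share a common width across batches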
+            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
+                gen_kwargs["max_new_tokens"] + 1
+            ):
+                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
         else:
             labels = None