diff --git a/convlab/base_models/t5/trainer.py b/convlab/base_models/t5/trainer.py
index 77f125575e0ebfb44e05ef16e4d8d041e016cc81..80b0bf2e3b3ec3121afeb71dbee2a4b21763cc44 100644
--- a/convlab/base_models/t5/trainer.py
+++ b/convlab/base_models/t5/trainer.py
@@ -12,17 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from email.policy import default
-from typing import Any, Dict, List, Optional, Tuple, Union
-from dataclasses import dataclass, field
-import torch
-from torch import nn
-from torch.utils.data import Dataset
-
-from transformers.deepspeed import is_deepspeed_zero3_enabled
-from transformers.trainer_utils import PredictionOutput
-from transformers.utils import logging, add_start_docstrings
+# from typing import Any, Dict, List, Optional, Tuple, Union
+# from dataclasses import dataclass, field
+# import torch
+# from torch import nn
+
+# from transformers.deepspeed import is_deepspeed_zero3_enabled
+# from transformers.utils import logging, cached_property, torch_required
+from transformers.training_args import (
+    os,
+    torch,
+    logging,
+    dataclass,
+    field,
+    Optional,
+    cached_property,
+    torch_required,
+    get_int_from_env,
+    is_torch_tpu_available,
+    is_sagemaker_mp_enabled,
+    is_sagemaker_dp_enabled,
+    dist,
+    xm,
+    smp
+)
+
+from transformers.trainer import (
+    nn,
+    Any, Dict, List, Tuple, Union,
+    is_deepspeed_zero3_enabled
+)
+
 from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from datetime import timedelta
 
 logger = logging.get_logger(__name__)
 
@@ -40,6 +62,105 @@ class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
     top_p: Optional[float] = field(default=1.0, metadata={"help": "If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation."})
     num_return_sequences: Optional[int] = field(default=1, metadata={"help": "The number of independently computed returned sequences for each element in the batch."})
 
+    @cached_property
+    @torch_required
+    def _setup_devices(self) -> "torch.device":
+        logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
+        if self.no_cuda:
+            device = torch.device("cpu")
+            self._n_gpu = 0
+            self.local_rank = get_int_from_env(
+                ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"],
+                self.local_rank,
+            )
+            if self.local_rank != -1 and not torch.distributed.is_initialized():
+                # Initializes distributed backend for cpu
+                if self.xpu_backend not in ("mpi", "ccl"):
+                    raise ValueError(
+                        "CPU distributed training backend is not properly set. "
+                        "Please set '--xpu_backend' to either 'mpi' or 'ccl'."
+                    )
+                if self.xpu_backend == "ccl" and int(os.environ.get("CCL_WORKER_COUNT", 0)) < 1:
+                    raise ValueError(
+                        "CPU distributed training backend is ccl. but CCL_WORKER_COUNT is not correctly set. "
+                        "Please use like 'export CCL_WORKER_COUNT = 1' to set."
+                    )
+
+                # Try to get launch configuration from environment variables set by MPI launcher - works for Intel MPI, OpenMPI and MVAPICH
+                rank = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0)
+                size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1)
+                local_size = get_int_from_env(
+                    ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
+                )
+                os.environ["RANK"] = str(rank)
+                os.environ["WORLD_SIZE"] = str(size)
+                os.environ["LOCAL_RANK"] = str(self.local_rank)
+                if not os.environ.get("MASTER_PORT", None):
+                    os.environ["MASTER_PORT"] = "29500"
+                if not os.environ.get("MASTER_ADDR", None):
+                    if local_size != size or self.xpu_backend != "mpi":
+                        raise ValueError(
+                            "Looks like distributed multinode run but MASTER_ADDR env not set, "
+                            "please try exporting rank 0's hostname as MASTER_ADDR"
+                        )
+                torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size, timeout=timedelta(days=365))
+        elif is_torch_tpu_available():
+            device = xm.xla_device()
+            self._n_gpu = 0
+        elif is_sagemaker_mp_enabled():
+            local_rank = smp.local_rank()
+            device = torch.device("cuda", local_rank)
+            self._n_gpu = 1
+        elif is_sagemaker_dp_enabled():
+            dist.init_process_group(backend="smddp", timeout=timedelta(days=365))
+            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
+        elif self.deepspeed:
+            # deepspeed inits torch.distributed internally
+            from transformers.deepspeed import is_deepspeed_available
+
+            if not is_deepspeed_available():
+                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
+            import deepspeed
+
+            deepspeed.init_distributed()
+
+            # workaround for setups like notebooks where the launcher can't be used,
+            # but deepspeed requires a dist env.
+            # env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
+            self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
+
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
+        elif self.local_rank == -1:
+            # if n_gpu is > 1 we'll use nn.DataParallel.
+            # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+            # trigger an error that a device index is missing. Index 0 takes into account the
+            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+            # will use the first GPU in that env, i.e. GPU#1
+            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
+            # the default value.
+            self._n_gpu = torch.cuda.device_count()
+        else:
+            # Here, we'll use torch.distributed.
+            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl", timeout=timedelta(days=365))
+            device = torch.device("cuda", self.local_rank)
+            self._n_gpu = 1
+
+        if device.type == "cuda":
+            torch.cuda.set_device(device)
+
+        return device
 
 
 class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
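Editor's note, not part of the commit: the long _setup_devices override above is a copy of TrainingArguments._setup_devices from transformers.training_args; the only behavioral change is that every init_process_group call (the ccl/mpi, smddp, and nccl branches) now passes timeout=timedelta(days=365) in place of PyTorch's default 30-minute collective timeout, so ranks that sit idle through a long evaluation or generation phase no longer trip the process-group timeout. A minimal sketch of the same idea outside the Trainer, assuming a torchrun-style launcher has already set RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT:

    from datetime import timedelta

    import torch.distributed as dist

    # torch.distributed.init_process_group accepts a `timeout` bounding how long
    # collectives may block; the patch raises it from the 30-minute default so
    # long idle phases on some ranks do not abort training.
    dist.init_process_group(backend="nccl", timeout=timedelta(days=365))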
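The generation-related fields visible in the second hunk's context (top_p, num_return_sequences) are plain dataclass fields, so they are set like any other training argument. A hypothetical construction with placeholder values, assuming the module path from this diff:

    from convlab.base_models.t5.trainer import ConvLabSeq2SeqTrainingArguments

    # Placeholder values; output_dir is required by the TrainingArguments base class.
    args = ConvLabSeq2SeqTrainingArguments(
        output_dir="output/t5",
        top_p=0.9,
        num_return_sequences=1,
    )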