diff --git a/convlab/base_models/t5/trainer.py b/convlab/base_models/t5/trainer.py index 80b0bf2e3b3ec3121afeb71dbee2a4b21763cc44..ba0bce934b9f173aecc797119d4f28c0c974251b 100644 --- a/convlab/base_models/t5/trainer.py +++ b/convlab/base_models/t5/trainer.py @@ -32,9 +32,6 @@ from transformers.training_args import ( is_torch_tpu_available, is_sagemaker_mp_enabled, is_sagemaker_dp_enabled, - dist, - xm, - smp ) from transformers.trainer import ( @@ -45,7 +42,7 @@ from transformers.trainer import ( from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments from datetime import timedelta - +import torch.distributed as dist logger = logging.get_logger(__name__)