diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py index 30e2b1d5e840292b030a626eacc8ac3628b89f5f..0679e36759904061f0cee8a9900bc9bf05d5e8d3 100644 --- a/convlab2/base_models/t5/run_seq2seq.py +++ b/convlab2/base_models/t5/run_seq2seq.py @@ -1,4 +1,4 @@ -#!python +#!/usr/bin/env python # coding=utf-8 # Copyright 2021 The HuggingFace Team. All rights reserved. # @@ -27,12 +27,10 @@ from dataclasses import dataclass, field from typing import Optional import datasets -import nltk # Here to have a nice missing dependency error message early on import numpy as np -from datasets import load_dataset, load_metric +from datasets import load_dataset import transformers -from filelock import FileLock from transformers import ( AutoConfig, AutoModelForSeq2SeqLM, @@ -43,7 +41,6 @@ from transformers import ( Seq2SeqTrainingArguments, set_seed, ) -from transformers.file_utils import is_offline_mode from transformers.trainer_utils import EvalPrediction, get_last_checkpoint from transformers.utils import check_min_version from transformers.utils.versions import require_version @@ -109,7 +106,7 @@ class DataTrainingArguments: task_name: Optional[str] = field( - default=None, metadata={"help": "The name of the task, e.g., response generation."} + default=None, metadata={"help": "The name of the task, e.g., rg (for rgresponse generation)."} ) dataset_name: Optional[str] = field( default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} @@ -223,7 +220,7 @@ class DataTrainingArguments: and self.validation_file is None and self.test_file is None ): - raise ValueError("Need either a dataset name or a training/validation file/test_file.") + raise ValueError("Need either a dataset name or a training/validation/testing file.") else: if self.train_file is not None: extension = self.train_file.split(".")[-1] @@ -238,8 +235,6 @@ class DataTrainingArguments: self.val_max_target_length = self.max_target_length - - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -586,6 +581,7 @@ def main(): trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) + # Predict if training_args.do_predict: logger.info("*** Predict ***") @@ -636,4 +632,4 @@ def _mp_fn(index): if __name__ == "__main__": - main() \ No newline at end of file + main()