diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py index 01b353f20890657236eef8c9faf8f4fa46054fb3..bc760d7e6fafb4a7db7e90bf4e364b04dd14c77a 100644 --- a/convlab2/base_models/t5/create_data.py +++ b/convlab2/base_models/t5/create_data.py @@ -3,7 +3,10 @@ import json from tqdm import tqdm from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data -def create_rg_data(data_by_split, data_dir): +def create_rg_data(dataset, data_dir): + data_by_split = load_rg_data(dataset) + os.makedirs(data_dir, exist_ok=True) + data_splits = data_by_split.keys() file_name = os.path.join(data_dir, f"source_prefix.txt") with open(file_name, "w") as f: @@ -30,7 +33,5 @@ if __name__ == '__main__': for dataset_name in tqdm(args.datasets, desc='datasets'): dataset = load_dataset(dataset_name) for task_name in tqdm(args.tasks, desc='tasks', leave=False): - data_by_split = eval(f"load_{task_name}_data")(dataset) data_dir = os.path.join(args.save_dir, task_name, dataset_name) - os.makedirs(data_dir, exist_ok=True) - eval(f"create_{task_name}_data")(data_by_split, data_dir) \ No newline at end of file + eval(f"create_{task_name}_data")(dataset, data_dir) diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py index 0679e36759904061f0cee8a9900bc9bf05d5e8d3..a7f4a2f47ce87804af3af6cf234dbcb196570e0e 100644 --- a/convlab2/base_models/t5/run_seq2seq.py +++ b/convlab2/base_models/t5/run_seq2seq.py @@ -47,9 +47,9 @@ from transformers.utils.versions import require_version # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -# check_min_version("4.16.0.dev0") +check_min_version("4.12.5") -# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") +require_version("datasets>=1.16.1") logger = logging.getLogger(__name__) os.environ["WANDB_DISABLED"] = "true"