From 582758df2803c4382db6ad8774cb49facfafcdad Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Tue, 28 Dec 2021 09:32:54 +0000 Subject: [PATCH] add package version check --- convlab2/base_models/t5/create_data.py | 9 +++++---- convlab2/base_models/t5/run_seq2seq.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py index 01b353f2..bc760d7e 100644 --- a/convlab2/base_models/t5/create_data.py +++ b/convlab2/base_models/t5/create_data.py @@ -3,7 +3,10 @@ import json from tqdm import tqdm from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data -def create_rg_data(data_by_split, data_dir): +def create_rg_data(dataset, data_dir): + data_by_split = load_rg_data(dataset) + os.makedirs(data_dir, exist_ok=True) + data_splits = data_by_split.keys() file_name = os.path.join(data_dir, f"source_prefix.txt") with open(file_name, "w") as f: @@ -30,7 +33,5 @@ if __name__ == '__main__': for dataset_name in tqdm(args.datasets, desc='datasets'): dataset = load_dataset(dataset_name) for task_name in tqdm(args.tasks, desc='tasks', leave=False): - data_by_split = eval(f"load_{task_name}_data")(dataset) data_dir = os.path.join(args.save_dir, task_name, dataset_name) - os.makedirs(data_dir, exist_ok=True) - eval(f"create_{task_name}_data")(data_by_split, data_dir) \ No newline at end of file + eval(f"create_{task_name}_data")(dataset, data_dir) diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py index 0679e367..a7f4a2f4 100644 --- a/convlab2/base_models/t5/run_seq2seq.py +++ b/convlab2/base_models/t5/run_seq2seq.py @@ -47,9 +47,9 @@ from transformers.utils.versions import require_version # Will error if the minimal version of Transformers is not installed. Remove at your own risks. -# check_min_version("4.16.0.dev0") +check_min_version("4.12.5") -# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") +require_version("datasets>=1.16.1") logger = logging.getLogger(__name__) os.environ["WANDB_DISABLED"] = "true" -- GitLab