From 582758df2803c4382db6ad8774cb49facfafcdad Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Tue, 28 Dec 2021 09:32:54 +0000
Subject: [PATCH] add package version check

---
 convlab2/base_models/t5/create_data.py | 9 +++++----
 convlab2/base_models/t5/run_seq2seq.py | 4 ++--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py
index 01b353f2..bc760d7e 100644
--- a/convlab2/base_models/t5/create_data.py
+++ b/convlab2/base_models/t5/create_data.py
@@ -3,7 +3,10 @@ import json
 from tqdm import tqdm
 from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
 
-def create_rg_data(data_by_split, data_dir):
+def create_rg_data(dataset, data_dir):
+    data_by_split = load_rg_data(dataset)
+    os.makedirs(data_dir, exist_ok=True)
+
     data_splits = data_by_split.keys()
     file_name = os.path.join(data_dir, f"source_prefix.txt")
     with open(file_name, "w") as f:
@@ -30,7 +33,5 @@ if __name__ == '__main__':
     for dataset_name in tqdm(args.datasets, desc='datasets'):
         dataset = load_dataset(dataset_name)
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
-            data_by_split = eval(f"load_{task_name}_data")(dataset)
             data_dir = os.path.join(args.save_dir, task_name, dataset_name)
-            os.makedirs(data_dir, exist_ok=True)
-            eval(f"create_{task_name}_data")(data_by_split, data_dir)
\ No newline at end of file
+            eval(f"create_{task_name}_data")(dataset, data_dir)
diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py
index 0679e367..a7f4a2f4 100644
--- a/convlab2/base_models/t5/run_seq2seq.py
+++ b/convlab2/base_models/t5/run_seq2seq.py
@@ -47,9 +47,9 @@ from transformers.utils.versions import require_version
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-# check_min_version("4.16.0.dev0")
+check_min_version("4.12.5")
 
-# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
+require_version("datasets>=1.16.1")
 
 logger = logging.getLogger(__name__)
 os.environ["WANDB_DISABLED"] = "true"
-- 
GitLab