Skip to content
Snippets Groups Projects
Commit 582758df authored by zqwerty's avatar zqwerty
Browse files

add package version check

parent ce2d802d
Branches
Tags
No related merge requests found
......@@ -3,7 +3,10 @@ import json
from tqdm import tqdm
from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
def create_rg_data(data_by_split, data_dir):
def create_rg_data(dataset, data_dir):
data_by_split = load_rg_data(dataset)
os.makedirs(data_dir, exist_ok=True)
data_splits = data_by_split.keys()
file_name = os.path.join(data_dir, f"source_prefix.txt")
with open(file_name, "w") as f:
......@@ -30,7 +33,5 @@ if __name__ == '__main__':
for dataset_name in tqdm(args.datasets, desc='datasets'):
dataset = load_dataset(dataset_name)
for task_name in tqdm(args.tasks, desc='tasks', leave=False):
data_by_split = eval(f"load_{task_name}_data")(dataset)
data_dir = os.path.join(args.save_dir, task_name, dataset_name)
os.makedirs(data_dir, exist_ok=True)
eval(f"create_{task_name}_data")(data_by_split, data_dir)
\ No newline at end of file
eval(f"create_{task_name}_data")(dataset, data_dir)
......@@ -47,9 +47,9 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.16.0.dev0")
check_min_version("4.12.5")
# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
require_version("datasets>=1.16.1")
logger = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment