Skip to content
Snippets Groups Projects
Commit 582758df authored by zqwerty's avatar zqwerty
Browse files

add package version check

parent ce2d802d
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,10 @@ import json
from tqdm import tqdm
from convlab2.util import load_dataset, load_nlu_data, load_dst_data, load_policy_data, load_nlg_data, load_e2e_data, load_rg_data
def create_rg_data(data_by_split, data_dir):
def create_rg_data(dataset, data_dir):
data_by_split = load_rg_data(dataset)
os.makedirs(data_dir, exist_ok=True)
data_splits = data_by_split.keys()
file_name = os.path.join(data_dir, f"source_prefix.txt")
with open(file_name, "w") as f:
......@@ -30,7 +33,5 @@ if __name__ == '__main__':
for dataset_name in tqdm(args.datasets, desc='datasets'):
dataset = load_dataset(dataset_name)
for task_name in tqdm(args.tasks, desc='tasks', leave=False):
data_by_split = eval(f"load_{task_name}_data")(dataset)
data_dir = os.path.join(args.save_dir, task_name, dataset_name)
os.makedirs(data_dir, exist_ok=True)
eval(f"create_{task_name}_data")(data_by_split, data_dir)
\ No newline at end of file
eval(f"create_{task_name}_data")(dataset, data_dir)
......@@ -47,9 +47,9 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# check_min_version("4.16.0.dev0")
check_min_version("4.12.5")
# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
require_version("datasets>=1.16.1")
logger = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment