diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 13d54b7f2e9141e462ca557b7a5d2802f0c294e4..b378fd2e79486c0c9936f887a307ca1746e4e1aa 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -5,8 +5,13 @@ import zipfile from convlab2 import DATA_ROOT +def get_subdir(subtask): + subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' + return subdir + + def prepare_data(subtask, split, data_root=DATA_ROOT): - data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') + data_dir = os.path.join(data_root, get_subdir(subtask)) zip_filename = os.path.join(data_dir, f'{split}.json.zip') test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json')) data = {} @@ -131,8 +136,3 @@ def eval_states(gt, pred, subtask): 'f1': f1, } } - - -def get_subdir(subtask): - subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' - return subdir