diff --git a/convlab2/dst/dstc9/.gitignore b/convlab2/dst/dstc9/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5e8e7e0b6ae3d59b597fbd4dd6561366bb76163b --- /dev/null +++ b/convlab2/dst/dstc9/.gitignore @@ -0,0 +1 @@ +**/*.json diff --git a/convlab2/dst/dstc9/__init__.py b/convlab2/dst/dstc9/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..1abf4a7ec9b05093db1ded5e3400b684d5585c32 100644 --- a/convlab2/dst/dstc9/__init__.py +++ b/convlab2/dst/dstc9/__init__.py @@ -0,0 +1,2 @@ +from .eval_file import evaluate as eval_file +from .eval_model import evaluate as eval_model diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index fa3187d556bc9997cb1bf130f94afdafbe852ff4..b53b46a2412dd2ddd91a2f93e3dac9b96471b622 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -2,28 +2,46 @@ evaluate output file """ -from convlab2.dst.dstc9.utils import prepare_data, eval_states +import os +import json -if __name__ == '__main__': - import os - import json - from argparse import ArgumentParser - parser = ArgumentParser() - parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) - args = parser.parse_args() +from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir - gt = { - dialog_id: [state for _, _, state in turns] - for dialog_id, turns in prepare_data(args.subtask).items() - } - # json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4) +def evaluate(model_dir, subtask, gt): + subdir = get_subdir(subtask) results = {} for i in range(1, 6): - filename = f'submission{i}.json' - if not os.path.exists(filename): + filepath = os.path.join(model_dir, subdir, f'submission{i}.json') + if not os.path.exists(filepath): continue - pred = json.load(open(filename)) - results[filename] = eval_states(gt, pred) + pred = json.load(open(filepath)) + results[i] = eval_states(gt, pred) + + json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) + - json.dump(results, open('results.json', 'w'), indent=4, ensure_ascii=False) +def dump_example(subtask, split): + test_data = prepare_data(subtask, split) + gt = extract_gt(test_data) + json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4) + for dialog_id, states in gt.items(): + for state in states: + for domain in state.values(): + for slot in domain: + domain[slot] = "" + json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4) + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) + parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) + args = parser.parse_args() + subtask = args.subtask + split = args.split + dump_example(subtask, split) + test_data = prepare_data(subtask, split) + gt = extract_gt(test_data) + evaluate('example', subtask, gt) diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 9b86d0577be01c4532688ec9782bc184cb0dd23b..1b84ea05949c00fb537c8fde078943f62707a38a 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -7,36 +7,36 @@ import json import importlib from convlab2.dst import DST -from convlab2.dst.dstc9.utils import prepare_data, eval_states +from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir -def evaluate(model_name, subtask): - subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' - module = importlib.import_module(f'{model_name}.{subdir}') +def evaluate(model_dir, subtask, test_data, gt): + subdir = get_subdir(subtask) + module = importlib.import_module(f'{model_dir}.{subdir}') assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root' model_cls = module.__getattribute__('Model') assert issubclass(model_cls, DST), 'the model must implement DST interface' # load weights, set eval() on default model = model_cls() - gt = {} pred = {} - for dialog_id, turns in prepare_data(subtask).items(): - gt_dialog = [] - pred_dialog = [] + for dialog_id, turns in test_data.items(): model.init_session() - for sys_utt, user_utt, gt_turn in turns: - gt_dialog.append(gt_turn) - pred_dialog.append(model.update_turn(sys_utt, user_utt)) - gt[dialog_id] = gt_dialog - pred[dialog_id] = pred_dialog + pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] result = eval_states(gt, pred) print(result) - json.dump(result, open(os.path.join(model_name, subdir, 'result.json'), 'w'), indent=4, ensure_ascii=False) + json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) if __name__ == '__main__': from argparse import ArgumentParser parser = ArgumentParser() parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) + parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) args = parser.parse_args() - evaluate('example', args.subtask) + subtask = args.subtask + test_data = prepare_data(subtask, args.split) + gt = { + dialog_id: [state for _, _, state in turns] + for dialog_id, turns in test_data.items() + } + evaluate('example', subtask, test_data, gt) diff --git a/convlab2/dst/dstc9/example/.gitignore b/convlab2/dst/dstc9/example/.gitignore deleted file mode 100644 index 4b544e7771d6b7257bf1ea3aedbf42519654abad..0000000000000000000000000000000000000000 --- a/convlab2/dst/dstc9/example/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*/result.json diff --git a/convlab2/dst/dstc9/example/multiwoz_zh/model.py b/convlab2/dst/dstc9/example/multiwoz_zh/model.py index 1a81cdb34909d5ec4b29428c47281ec8c47db9c9..cd6ef3615052372d8113c037884fe7f60081b4c8 100644 --- a/convlab2/dst/dstc9/example/multiwoz_zh/model.py +++ b/convlab2/dst/dstc9/example/multiwoz_zh/model.py @@ -51,8 +51,8 @@ class ExampleModel(DST): "出发时间": "", "目的地": "", "日期": "", - "到达时间": "未提及", - "出发地": "未提及", + "到达时间": "", + "出发地": "", }, } diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 75b60774d9f076067ba84e922e18875fecf4bc60..80cd25d99a16550a23f14bbae326eee500033777 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -2,21 +2,13 @@ import os import json import zipfile +from convlab2 import DATA_ROOT -def load_test_data(subtask): - from convlab2 import DATA_ROOT - data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') - # test public data currently - # to check if this script works properly with your code when label information is - # not available, you may need to fill the missing fields yourself (with any value) - zip_filename = os.path.join(data_dir, 'dstc9-test-250.zip') - test_data = json.load(zipfile.ZipFile(zip_filename).open('data.json')) - assert len(test_data) == 250 - return test_data - -def prepare_data(subtask): - test_data = load_test_data(subtask) +def prepare_data(subtask, split, data_root=DATA_ROOT): + data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') + zip_filename = os.path.join(data_dir, f'{split}.json.zip') + test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json')) data = {} if subtask == 'multiwoz': for dialog_id, dialog in test_data.items(): @@ -57,6 +49,14 @@ def prepare_data(subtask): return data +def extract_gt(test_data): + gt = { + dialog_id: [state for _, _, state in turns] + for dialog_id, turns in test_data.items() + } + return gt + + def eval_states(gt, pred): def exception(description, **kargs): ret = { @@ -116,3 +116,8 @@ def eval_states(gt, pred): # 'f1': f1, # } } + + +def get_subdir(subtask): + subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' + return subdir