From d0ec9473d25c44f34de45f770437e4ec1bffcfb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B4=9A=E9=AA=81?= <function2@qq.com> Date: Tue, 22 Sep 2020 19:08:34 +0800 Subject: [PATCH] revise xldst evaluation (#124) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example --- convlab2/dst/dstc9/.gitignore | 1 + convlab2/dst/dstc9/__init__.py | 2 + convlab2/dst/dstc9/eval_file.py | 54 ++++++++++++------- convlab2/dst/dstc9/eval_model.py | 30 +++++------ convlab2/dst/dstc9/example/.gitignore | 1 - .../dst/dstc9/example/multiwoz_zh/model.py | 4 +- convlab2/dst/dstc9/utils.py | 31 ++++++----- 7 files changed, 74 insertions(+), 49 deletions(-) create mode 100644 convlab2/dst/dstc9/.gitignore delete mode 100644 convlab2/dst/dstc9/example/.gitignore diff --git a/convlab2/dst/dstc9/.gitignore b/convlab2/dst/dstc9/.gitignore new file mode 100644 index 0000000..5e8e7e0 --- /dev/null +++ b/convlab2/dst/dstc9/.gitignore @@ -0,0 +1 @@ +**/*.json diff --git a/convlab2/dst/dstc9/__init__.py b/convlab2/dst/dstc9/__init__.py index e69de29..1abf4a7 100644 --- a/convlab2/dst/dstc9/__init__.py +++ b/convlab2/dst/dstc9/__init__.py @@ -0,0 +1,2 @@ +from .eval_file import evaluate as eval_file +from .eval_model import evaluate as eval_model diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index fa3187d..b53b46a 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -2,28 +2,46 @@ evaluate output file """ -from convlab2.dst.dstc9.utils import prepare_data, eval_states +import os +import json -if __name__ == '__main__': - import os - import json - from argparse import ArgumentParser - parser = ArgumentParser() - parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) - args = parser.parse_args() +from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir - gt = { - dialog_id: [state for _, _, state in turns] - for dialog_id, turns in prepare_data(args.subtask).items() - } - # json.dump(gt, open('gt-crosswoz.json', 'w'), ensure_ascii=False, indent=4) +def evaluate(model_dir, subtask, gt): + subdir = get_subdir(subtask) results = {} for i in range(1, 6): - filename = f'submission{i}.json' - if not os.path.exists(filename): + filepath = os.path.join(model_dir, subdir, f'submission{i}.json') + if not os.path.exists(filepath): continue - pred = json.load(open(filename)) - results[filename] = eval_states(gt, pred) + pred = json.load(open(filepath)) + results[i] = eval_states(gt, pred) + + json.dump(results, open(os.path.join(model_dir, subdir, 'file-results.json'), 'w'), indent=4, ensure_ascii=False) + - json.dump(results, open('results.json', 'w'), indent=4, ensure_ascii=False) +def dump_example(subtask, split): + test_data = prepare_data(subtask, split) + gt = extract_gt(test_data) + json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission1.json'), 'w'), ensure_ascii=False, indent=4) + for dialog_id, states in gt.items(): + for state in states: + for domain in state.values(): + for slot in domain: + domain[slot] = "" + json.dump(gt, open(os.path.join('example', get_subdir(subtask), 'submission2.json'), 'w'), ensure_ascii=False, indent=4) + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) + parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) + args = parser.parse_args() + subtask = args.subtask + split = args.split + dump_example(subtask, split) + test_data = prepare_data(subtask, split) + gt = extract_gt(test_data) + evaluate('example', subtask, gt) diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 9b86d05..1b84ea0 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -7,36 +7,36 @@ import json import importlib from convlab2.dst import DST -from convlab2.dst.dstc9.utils import prepare_data, eval_states +from convlab2.dst.dstc9.utils import prepare_data, eval_states, get_subdir -def evaluate(model_name, subtask): - subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' - module = importlib.import_module(f'{model_name}.{subdir}') +def evaluate(model_dir, subtask, test_data, gt): + subdir = get_subdir(subtask) + module = importlib.import_module(f'{model_dir}.{subdir}') assert 'Model' in dir(module), 'please import your model as name `Model` in your subtask module root' model_cls = module.__getattribute__('Model') assert issubclass(model_cls, DST), 'the model must implement DST interface' # load weights, set eval() on default model = model_cls() - gt = {} pred = {} - for dialog_id, turns in prepare_data(subtask).items(): - gt_dialog = [] - pred_dialog = [] + for dialog_id, turns in test_data.items(): model.init_session() - for sys_utt, user_utt, gt_turn in turns: - gt_dialog.append(gt_turn) - pred_dialog.append(model.update_turn(sys_utt, user_utt)) - gt[dialog_id] = gt_dialog - pred[dialog_id] = pred_dialog + pred[dialog_id] = [model.update_turn(sys_utt, user_utt) for sys_utt, user_utt, gt_turn in turns] result = eval_states(gt, pred) print(result) - json.dump(result, open(os.path.join(model_name, subdir, 'result.json'), 'w'), indent=4, ensure_ascii=False) + json.dump(result, open(os.path.join(model_dir, subdir, 'model-result.json'), 'w'), indent=4, ensure_ascii=False) if __name__ == '__main__': from argparse import ArgumentParser parser = ArgumentParser() parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) + parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) args = parser.parse_args() - evaluate('example', args.subtask) + subtask = args.subtask + test_data = prepare_data(subtask, args.split) + gt = { + dialog_id: [state for _, _, state in turns] + for dialog_id, turns in test_data.items() + } + evaluate('example', subtask, test_data, gt) diff --git a/convlab2/dst/dstc9/example/.gitignore b/convlab2/dst/dstc9/example/.gitignore deleted file mode 100644 index 4b544e7..0000000 --- a/convlab2/dst/dstc9/example/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*/result.json diff --git a/convlab2/dst/dstc9/example/multiwoz_zh/model.py b/convlab2/dst/dstc9/example/multiwoz_zh/model.py index 1a81cdb..cd6ef36 100644 --- a/convlab2/dst/dstc9/example/multiwoz_zh/model.py +++ b/convlab2/dst/dstc9/example/multiwoz_zh/model.py @@ -51,8 +51,8 @@ class ExampleModel(DST): "出发时间": "", "目的地": "", "日期": "", - "到达时间": "未提及", - "出发地": "未提及", + "到达时间": "", + "出发地": "", }, } diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 75b6077..80cd25d 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -2,21 +2,13 @@ import os import json import zipfile +from convlab2 import DATA_ROOT -def load_test_data(subtask): - from convlab2 import DATA_ROOT - data_dir = os.path.join(DATA_ROOT, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') - # test public data currently - # to check if this script works properly with your code when label information is - # not available, you may need to fill the missing fields yourself (with any value) - zip_filename = os.path.join(data_dir, 'dstc9-test-250.zip') - test_data = json.load(zipfile.ZipFile(zip_filename).open('data.json')) - assert len(test_data) == 250 - return test_data - -def prepare_data(subtask): - test_data = load_test_data(subtask) +def prepare_data(subtask, split, data_root=DATA_ROOT): + data_dir = os.path.join(data_root, 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en') + zip_filename = os.path.join(data_dir, f'{split}.json.zip') + test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json')) data = {} if subtask == 'multiwoz': for dialog_id, dialog in test_data.items(): @@ -57,6 +49,14 @@ def prepare_data(subtask): return data +def extract_gt(test_data): + gt = { + dialog_id: [state for _, _, state in turns] + for dialog_id, turns in test_data.items() + } + return gt + + def eval_states(gt, pred): def exception(description, **kargs): ret = { @@ -116,3 +116,8 @@ def eval_states(gt, pred): # 'f1': f1, # } } + + +def get_subdir(subtask): + subdir = 'multiwoz_zh' if subtask == 'multiwoz' else 'crosswoz_en' + return subdir -- GitLab