From e0ee065c04a5d496be3402d13209cc5c2da0c59e Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Sun, 1 Nov 2020 17:13:53 +0800 Subject: [PATCH] add correct name label argument --- convlab2/dst/dstc9/eval_file.py | 5 +++-- convlab2/dst/dstc9/eval_model.py | 9 +++++---- convlab2/dst/dstc9/utils.py | 26 +++++++++++++++++++++----- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index fe33b02..38b7577 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -51,10 +51,11 @@ if __name__ == '__main__': from argparse import ArgumentParser parser = ArgumentParser() parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) - parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) + parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val', 'dstc9-250']) + parser.add_argument('--correct_name_label', action='store_true') args = parser.parse_args() subtask = args.subtask split = args.split dump_example(subtask, split) - test_data = prepare_data(subtask, split) + test_data = prepare_data(subtask, split, correct_name_label=args.correct_name_label) gt = extract_gt(test_data) diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 0876187..4544128 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -34,9 +34,9 @@ def evaluate(model_dir, subtask, test_data, gt): dump_result(model_dir, 'model-result.json', result, errors, pred) -def eval_team(team): +def eval_team(team, correct_name_label): for subtask in ['multiwoz', 'crosswoz']: - test_data = prepare_data(subtask, 'dstc9') + test_data = prepare_data(subtask, 'dstc9', correct_name_label=correct_name_label) gt = extract_gt(test_data) for i in range(1, 6): model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}') @@ -50,12 +50,13 @@ if __name__ == '__main__': from 
argparse import ArgumentParser parser = ArgumentParser() parser.add_argument('--teams', type=str, nargs='*') + parser.add_argument('--correct_name_label', action='store_true') args = parser.parse_args() if not args.teams: for team in os.listdir('.'): if not os.path.isdir(team): continue - eval_team(team) + eval_team(team, args.correct_name_label) else: for team in args.teams: - eval_team(team) + eval_team(team, args.correct_name_label) diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 588e90c..29a951a 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -11,7 +11,7 @@ def get_subdir(subtask): return subdir -def prepare_data(subtask, split, data_root=DATA_ROOT): +def prepare_data(subtask, split, data_root=DATA_ROOT, correct_name_label=False): data_dir = os.path.join(data_root, get_subdir(subtask)) zip_filename = os.path.join(data_dir, f'{split}.json.zip') test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json')) @@ -38,15 +38,31 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for dialog_id, dialog in test_data.items(): dialog_data = [] turns = dialog['messages'] + if correct_name_label: + selected_results = {domain_name: [] for domain_name in turns[1]['sys_state_init']} for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] dialog_state = {} for domain_name, state in turns[i + 1]['sys_state_init'].items(): - selected_results = state.pop('selectedResults') - if selected_results and 'name' in state and not state['name']: - state['name'] = selected_results - dialog_state[domain_name] = state + if correct_name_label: + state.pop('selectedResults') + sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults') + # if state has changed compared to previous sys state + state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name] + # clear the outdated previous selected results if state has been updated + 
if state_change: + selected_results[domain_name].clear() + if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1: + state['name'] = selected_results[domain_name][0] + dialog_state[domain_name] = state + if state_change and sys_selected_results: + selected_results[domain_name] = sys_selected_results + else: + selected_results = state.pop('selectedResults') + if selected_results and 'name' in state and not state['name']: + state['name'] = selected_results + dialog_state[domain_name] = state dialog_data.append((sys_utt, user_utt, dialog_state)) data[dialog_id] = dialog_data -- GitLab