diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py
index fe33b0244895a66f464dfe73822d6560c984bbd7..38b75772b659857c9184291c559939caa5456c9a 100644
--- a/convlab2/dst/dstc9/eval_file.py
+++ b/convlab2/dst/dstc9/eval_file.py
@@ -51,10 +51,11 @@ if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser()
     parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz'])
-    parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val'])
+    parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val', 'dstc9-250'])
+    parser.add_argument('--correct_name_label', action='store_true', help='apply name label correction when preparing data')
     args = parser.parse_args()
     subtask = args.subtask
     split = args.split
     dump_example(subtask, split)
-    test_data = prepare_data(subtask, split)
+    test_data = prepare_data(subtask, split, correct_name_label=args.correct_name_label)
     gt = extract_gt(test_data)
diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py
index 08761878519ba2a5727c8b6ac50297c6398dda4d..45441284e84c7c0059424efa984fab212fa1acc4 100644
--- a/convlab2/dst/dstc9/eval_model.py
+++ b/convlab2/dst/dstc9/eval_model.py
@@ -34,9 +34,9 @@ def evaluate(model_dir, subtask, test_data, gt):
     dump_result(model_dir, 'model-result.json', result, errors, pred)


-def eval_team(team):
+def eval_team(team, correct_name_label=False):
     for subtask in ['multiwoz', 'crosswoz']:
-        test_data = prepare_data(subtask, 'dstc9')
+        test_data = prepare_data(subtask, 'dstc9', correct_name_label=correct_name_label)
         gt = extract_gt(test_data)
         for i in range(1, 6):
             model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}')
@@ -50,12 +50,13 @@ if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser()
     parser.add_argument('--teams', type=str, nargs='*')
+    parser.add_argument('--correct_name_label', action='store_true', help='apply name label correction when preparing data')
     args = parser.parse_args()
     if not args.teams:
         for team in os.listdir('.'):
             if not os.path.isdir(team):
                 continue
-            eval_team(team)
+            eval_team(team, args.correct_name_label)
     else:
         for team in args.teams:
-            eval_team(team)
+            eval_team(team, args.correct_name_label)
diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index 06615fbc296c9215666ea10075a0ab6356105d16..3b706c5d25d6ea548b4a398fdfb65d2ed2834ef9 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -10,7 +10,7 @@ def get_subdir(subtask):
     return subdir


-def prepare_data(subtask, split, data_root=DATA_ROOT):
+def prepare_data(subtask, split, data_root=DATA_ROOT, correct_name_label=False):
     data_dir = os.path.join(data_root, get_subdir(subtask))
     zip_filename = os.path.join(data_dir, f'{split}.json.zip')
     test_data = json.load(zipfile.ZipFile(zip_filename).open(f'{split}.json'))
@@ -37,24 +37,30 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
         for dialog_id, dialog in test_data.items():
             dialog_data = []
             turns = dialog['messages']
-            selected_results = {k: [] for k in turns[1]['sys_state'].keys()}
+            if correct_name_label:
+                selected_results = {domain_name: [] for domain_name in turns[1]['sys_state_init']}
             for i in range(0, len(turns), 2):
                 sys_utt = turns[i - 1]['content'] if i else None
                 user_utt = turns[i]['content']
                 dialog_state = {}
                 for domain_name, state in turns[i + 1]['sys_state_init'].items():
-                    state.pop('selectedResults')
-                    sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
-                    # if state has changed compared to previous sys state
-                    state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
-                    # clear the outdated previous selected results if state has been updated
-                    if state_change:
-                        selected_results[domain_name].clear()
-                    if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
-                        state['name'] = selected_results[domain_name][0]
-                    dialog_state[domain_name] = state
-                    if state_change and sys_selected_results:
-                        selected_results[domain_name] = sys_selected_results
-
+                    if correct_name_label:
+                        state.pop('selectedResults')
+                        sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
+                        # if state has changed compared to previous sys state
+                        state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
+                        # clear the outdated previous selected results if state has been updated
+                        if state_change:
+                            selected_results[domain_name].clear()
+                        if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
+                            state['name'] = selected_results[domain_name][0]
+                        dialog_state[domain_name] = state
+                        if state_change and sys_selected_results:
+                            selected_results[domain_name] = sys_selected_results
+                    else:
+                        selected_results = state.pop('selectedResults')
+                        if selected_results and 'name' in state and not state['name']:
+                            state['name'] = selected_results
+                        dialog_state[domain_name] = state
                 dialog_data.append((sys_utt, user_utt, dialog_state))
             data[dialog_id] = dialog_data