diff --git a/convlab2/dst/dstc9/eval_model.py b/convlab2/dst/dstc9/eval_model.py index 3a3e0814266947728981d12be38b49ae7b85940b..08761878519ba2a5727c8b6ac50297c6398dda4d 100644 --- a/convlab2/dst/dstc9/eval_model.py +++ b/convlab2/dst/dstc9/eval_model.py @@ -9,7 +9,7 @@ import importlib from tqdm import tqdm from convlab2.dst import DST -from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result +from convlab2.dst.dstc9.utils import prepare_data, eval_states, dump_result, extract_gt def evaluate(model_dir, subtask, test_data, gt): @@ -34,16 +34,28 @@ def evaluate(model_dir, subtask, test_data, gt): dump_result(model_dir, 'model-result.json', result, errors, pred) +def eval_team(team): + for subtask in ['multiwoz', 'crosswoz']: + test_data = prepare_data(subtask, 'dstc9') + gt = extract_gt(test_data) + for i in range(1, 6): + model_dir = os.path.join(team, f'{subtask}-dst', f'submission{i}') + if not os.path.exists(model_dir): + continue + print(model_dir) + evaluate(model_dir, subtask, test_data, gt) + + if __name__ == '__main__': from argparse import ArgumentParser parser = ArgumentParser() - parser.add_argument('subtask', type=str, choices=['multiwoz', 'crosswoz']) - parser.add_argument('split', type=str, choices=['train', 'val', 'test', 'human_val']) + parser.add_argument('--teams', type=str, nargs='*') args = parser.parse_args() - subtask = args.subtask - test_data = prepare_data(subtask, args.split) - gt = { - dialog_id: [state for _, _, state in turns] - for dialog_id, turns in test_data.items() - } - evaluate('example', subtask, test_data, gt) + if not args.teams: + for team in os.listdir('.'): + if not os.path.isdir(team): + continue + eval_team(team) + else: + for team in args.teams: + eval_team(team) diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index aee61056dca0d651654b5135da165fc0d20b4d9b..06615fbc296c9215666ea10075a0ab6356105d16 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -1,7 +1,6 @@ -import os import json +import os import zipfile -from copy import deepcopy from convlab2 import DATA_ROOT @@ -23,31 +22,41 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['text'] if i else None user_utt = turns[i]['text'] - state = {} + dialog_state = {} for domain_name, domain in turns[i + 1]['metadata'].items(): if domain_name in ['警察机关', '医院', '公共汽车']: continue - domain_state = {} + state = {} for slots in domain.values(): for slot_name, value in slots.items(): - domain_state[slot_name] = value - state[domain_name] = domain_state - dialog_data.append((sys_utt, user_utt, state)) + state[slot_name] = value + dialog_state[domain_name] = state + dialog_data.append((sys_utt, user_utt, dialog_state)) data[dialog_id] = dialog_data else: for dialog_id, dialog in test_data.items(): dialog_data = [] turns = dialog['messages'] + selected_results = {k: [] for k in turns[1]['sys_state'].keys()} for i in range(0, len(turns), 2): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] - state = {} - for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): - selected_results = domain_state.pop('selectedResults') - if selected_results and 'name' in domain_state and not domain_state['name']: - domain_state['name'] = selected_results - state[domain_name] = domain_state - dialog_data.append((sys_utt, user_utt, state)) + dialog_state = {} + for domain_name, state in turns[i + 1]['sys_state_init'].items(): + state.pop('selectedResults') + sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults') + # if state has changed compared to previous sys state + state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name] + # clear the outdated previous selected results if state has been updated + if state_change: + selected_results[domain_name].clear() + if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1: + state['name'] = selected_results[domain_name][0] + dialog_state[domain_name] = state + if state_change and sys_selected_results: + selected_results[domain_name] = sys_selected_results + + dialog_data.append((sys_utt, user_utt, dialog_state)) data[dialog_id] = dialog_data return data @@ -63,12 +72,6 @@ def extract_gt(test_data): # for unifying values with the same meaning to the same expression def unify_value(value, subtask): - if isinstance(value, list): - ret = deepcopy(value) - for i, v in enumerate(ret): - ret[i] = unify_value(v, subtask) - return ret - value = value.lower() value = { 'multiwoz': { @@ -79,10 +82,11 @@ def unify_value(value, subtask): }, 'crosswoz': { 'none': '', + 'free admission': 'free', } }[subtask].get(value, value) - return ' '.join(value.strip().split()) + return ''.join(value.strip().split()) def eval_states(gt, pred, subtask): @@ -94,7 +98,7 @@ def eval_states(gt, pred, subtask): for k, v in kargs.items(): ret[k] = v return ret, None - errors = [] + errors = [['dialog id', 'turn id', 'domain name', 'slot name', 'ground truth', 'predict']] joint_acc, joint_tot = 0, 0 slot_acc, slot_tot = 0, 0 @@ -122,12 +126,12 @@ def eval_states(gt, pred, subtask): pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 - if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: + if gt_value == pred_value: slot_acc += 1 if gt_value: tp += 1 else: - errors.append([gt_value, pred_value]) + errors.append([dialog_id, turn_id, domain_name, slot_name, gt_value, pred_value]) turn_result = False if gt_value: fn += 1