diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index a89c2e3e8795608b1dfec26652b4d8d2c9ca5800..4b3e1b26106c318bd139a401d2fbeeba13bbe481 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -2,9 +2,8 @@ evaluate output file """ -import os import json -from copy import deepcopy +import os from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index b378fd2e79486c0c9936f887a307ca1746e4e1aa..2938f36f69e95434c4c51698bb6acd59e82dd61c 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -1,6 +1,7 @@ import os import json import zipfile +from copy import deepcopy from convlab2 import DATA_ROOT @@ -41,12 +42,10 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] state = {} - for domain_name, domain in turns[i + 1]['sys_state_init'].items(): - domain_state = {} - for slot_name, value in domain.items(): - if slot_name == 'selectedResults': - continue - domain_state[slot_name] = value + for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): + selected_results = domain_state.pop('selectedResults') + if selected_results and 'name' in domain_state and not domain_state['name']: + domain_state['name'] = selected_results state[domain_name] = domain_state dialog_data.append((sys_utt, user_utt, state)) data[dialog_id] = dialog_data @@ -62,20 +61,28 @@ def extract_gt(test_data): return gt -def eval_states(gt, pred, subtask): - # for unifying values with the same meaning to the same expression - value_unifier = { - 'multiwoz': { +# for unifying values with the same meaning to the same expression +def unify_value(value, subtask): + if isinstance(value, list): + ret = deepcopy(value) + for i, v in enumerate(ret): + ret[i] = unify_value(v, subtask) + return ret + return { + 'multiwoz': { + '未提及': '', + 'none': '', + '是的': '有', + '不是': '没有', }, 'crosswoz': { - '未提及': '', + 'None': '', } - }[subtask] + }[subtask].get(value, value) - def unify_value(value): - return value_unifier.get(value, value) +def eval_states(gt, pred, subtask): def exception(description, **kargs): ret = { 'status': 'exception', @@ -107,10 +114,10 @@ def eval_states(gt, pred, subtask): for slot_name, gt_value in gt_domain.items(): if slot_name not in pred_domain: return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) - gt_value = unify_value(gt_value) - pred_value = unify_value(pred_domain[slot_name]) + gt_value = unify_value(gt_value, subtask) + pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 - if gt_value == pred_value: + if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 if gt_value: tp += 1