From 0d45aa65d4cae704972569a801a4375e42075f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B4=9A=E9=AA=81?= <function2@qq.com> Date: Thu, 1 Oct 2020 13:31:34 +0800 Subject: [PATCH] fix XLDST evaluation (#141) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example * update precision, recall, f1 calculation * minor change * fix a database typo * use selectedResults for missing name * add value unification --- convlab2/dst/dstc9/eval_file.py | 3 +-- convlab2/dst/dstc9/utils.py | 41 +++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py index a89c2e3..4b3e1b2 100644 --- a/convlab2/dst/dstc9/eval_file.py +++ b/convlab2/dst/dstc9/eval_file.py @@ -2,9 +2,8 @@ evaluate output file """ -import os import json -from copy import deepcopy +import os from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index b378fd2..2938f36 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -1,6 +1,7 @@ import os import json import zipfile +from copy import deepcopy from convlab2 import DATA_ROOT @@ -41,12 +42,10 @@ def prepare_data(subtask, split, data_root=DATA_ROOT): sys_utt = turns[i - 1]['content'] if i else None user_utt = turns[i]['content'] state = {} - for domain_name, domain in turns[i + 1]['sys_state_init'].items(): - domain_state = {} - for slot_name, value in domain.items(): - if slot_name == 'selectedResults': - continue - domain_state[slot_name] = value + for domain_name, domain_state in turns[i + 1]['sys_state_init'].items(): + selected_results = domain_state.pop('selectedResults') + if selected_results and 'name' in domain_state and not domain_state['name']: + domain_state['name'] = selected_results state[domain_name] = domain_state dialog_data.append((sys_utt, user_utt, state)) data[dialog_id] = dialog_data @@ -62,20 +61,28 @@ def extract_gt(test_data): return gt -def eval_states(gt, pred, subtask): - # for unifying values with the same meaning to the same expression - value_unifier = { - 'multiwoz': { +# for unifying values with the same meaning to the same expression +def unify_value(value, subtask): + if isinstance(value, list): + ret = deepcopy(value) + for i, v in enumerate(ret): + ret[i] = unify_value(v, subtask) + return ret + return { + 'multiwoz': { + '未提及': '', + 'none': '', + '是的': '有', + '不是': '没有', }, 'crosswoz': { - '未提及': '', + 'None': '', } - }[subtask] + }[subtask].get(value, value) - def unify_value(value): - return value_unifier.get(value, value) +def eval_states(gt, pred, subtask): def exception(description, **kargs): ret = { 'status': 'exception', @@ -107,10 +114,10 @@ def eval_states(gt, pred, subtask): for slot_name, gt_value in gt_domain.items(): if slot_name not in pred_domain: return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) - gt_value = unify_value(gt_value) - pred_value = unify_value(pred_domain[slot_name]) + gt_value = unify_value(gt_value, subtask) + pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 - if gt_value == pred_value: + if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 if gt_value: tp += 1 -- GitLab