From e80628bb9b716f4a236e87bd3c871e58d965df40 Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Thu, 1 Oct 2020 10:59:25 +0800 Subject: [PATCH] add value unification --- convlab2/dst/dstc9/utils.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index c926103..2938f36 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -1,6 +1,7 @@ import os import json import zipfile +from copy import deepcopy from convlab2 import DATA_ROOT @@ -60,22 +61,28 @@ def extract_gt(test_data): return gt -def eval_states(gt, pred, subtask): - # for unifying values with the same meaning to the same expression - value_unifier = { +# for unifying values with the same meaning to the same expression +def unify_value(value, subtask): + if isinstance(value, list): + ret = deepcopy(value) + for i, v in enumerate(ret): + ret[i] = unify_value(v, subtask) + return ret + + return { 'multiwoz': { '未提及': '', + 'none': '', + '是的': '有', + '不是': '没有', }, 'crosswoz': { - + 'None': '', } - }[subtask] + }[subtask].get(value, value) - def unify_value(value): - if isinstance(value, list): - return list(map(unify_value, value)) - return value_unifier.get(value, value) +def eval_states(gt, pred, subtask): def exception(description, **kargs): ret = { 'status': 'exception', @@ -107,8 +114,8 @@ def eval_states(gt, pred, subtask): for slot_name, gt_value in gt_domain.items(): if slot_name not in pred_domain: return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) - gt_value = unify_value(gt_value) - pred_value = unify_value(pred_domain[slot_name]) + gt_value = unify_value(gt_value, subtask) + pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 -- GitLab