diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index c926103684d23dc11adef2cb266025b64849ab7f..2938f36f69e95434c4c51698bb6acd59e82dd61c 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -1,6 +1,7 @@ import os import json import zipfile +from copy import deepcopy from convlab2 import DATA_ROOT @@ -60,22 +61,28 @@ def extract_gt(test_data): return gt -def eval_states(gt, pred, subtask): - # for unifying values with the same meaning to the same expression - value_unifier = { +# for unifying values with the same meaning to the same expression +def unify_value(value, subtask): + if isinstance(value, list): + ret = deepcopy(value) + for i, v in enumerate(ret): + ret[i] = unify_value(v, subtask) + return ret + + return { 'multiwoz': { '未提及': '', + 'none': '', + '是的': '有', + '不是': '没有', }, 'crosswoz': { - + 'None': '', } - }[subtask] + }[subtask].get(value, value) - def unify_value(value): - if isinstance(value, list): - return list(map(unify_value, value)) - return value_unifier.get(value, value) +def eval_states(gt, pred, subtask): def exception(description, **kargs): ret = { 'status': 'exception', @@ -107,8 +114,8 @@ def eval_states(gt, pred, subtask): for slot_name, gt_value in gt_domain.items(): if slot_name not in pred_domain: return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name) - gt_value = unify_value(gt_value) - pred_value = unify_value(pred_domain[slot_name]) + gt_value = unify_value(gt_value, subtask) + pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1