diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 3b706c5d25d6ea548b4a398fdfb65d2ed2834ef9..cfa1f4dd628dee46eef0d83e9da02c12fd4335ca 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -78,6 +78,11 @@ def extract_gt(test_data): # for unifying values with the same meaning to the same expression def unify_value(value, subtask): + if isinstance(value, list): + for i, v in enumerate(value): + value[i] = unify_value(v, subtask) + return value + value = value.lower() value = { 'multiwoz': { @@ -132,7 +137,7 @@ def eval_states(gt, pred, subtask): pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 - if gt_value == pred_value: + if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 if gt_value: tp += 1