From 341c1198d6613c668a1a37bbc46d1bae5c119810 Mon Sep 17 00:00:00 2001 From: function2 <function2@qq.com> Date: Sun, 1 Nov 2020 17:24:42 +0800 Subject: [PATCH] fix version issue --- convlab2/dst/dstc9/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py index 3b706c5..cfa1f4d 100644 --- a/convlab2/dst/dstc9/utils.py +++ b/convlab2/dst/dstc9/utils.py @@ -78,6 +78,11 @@ def extract_gt(test_data): # for unifying values with the same meaning to the same expression def unify_value(value, subtask): + if isinstance(value, list): + for i, v in enumerate(value): + value[i] = unify_value(v, subtask) + return value + value = value.lower() value = { 'multiwoz': { @@ -132,7 +137,7 @@ def eval_states(gt, pred, subtask): pred_value = unify_value(pred_domain[slot_name], subtask) slot_tot += 1 - if gt_value == pred_value: + if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value: slot_acc += 1 if gt_value: tp += 1 -- GitLab