From e80628bb9b716f4a236e87bd3c871e58d965df40 Mon Sep 17 00:00:00 2001
From: function2 <function2@qq.com>
Date: Thu, 1 Oct 2020 10:59:25 +0800
Subject: [PATCH] add value unification

---
 convlab2/dst/dstc9/utils.py | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index c926103..2938f36 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -1,6 +1,7 @@
 import os
 import json
 import zipfile
+from copy import deepcopy
 
 from convlab2 import DATA_ROOT
 
@@ -60,22 +61,28 @@ def extract_gt(test_data):
     return gt
 
 
-def eval_states(gt, pred, subtask):
-    # for unifying values with the same meaning to the same expression
-    value_unifier = {
+# for unifying values with the same meaning to the same expression
+def unify_value(value, subtask):
+    if isinstance(value, list):
+        ret = deepcopy(value)
+        for i, v in enumerate(ret):
+            ret[i] = unify_value(v, subtask)
+        return ret
+
+    return {
         'multiwoz': {
             '未提及': '',
+            'none': '',
+            '是的': '有',
+            '不是': '没有',
         },
         'crosswoz': {
-
+            'None': '',
         }
-    }[subtask]
+    }[subtask].get(value, value)
 
-    def unify_value(value):
-        if isinstance(value, list):
-            return list(map(unify_value, value))
-        return value_unifier.get(value, value)
 
+def eval_states(gt, pred, subtask):
     def exception(description, **kargs):
         ret = {
             'status': 'exception',
@@ -107,8 +114,8 @@ def eval_states(gt, pred, subtask):
                 for slot_name, gt_value in gt_domain.items():
                     if slot_name not in pred_domain:
                         return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
-                    gt_value = unify_value(gt_value)
-                    pred_value = unify_value(pred_domain[slot_name])
+                    gt_value = unify_value(gt_value, subtask)
+                    pred_value = unify_value(pred_domain[slot_name], subtask)
                     slot_tot += 1
                     if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
                         slot_acc += 1
-- 
GitLab