From 0d45aa65d4cae704972569a801a4375e42075f7f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BD=97=E5=B4=9A=E9=AA=81?= <function2@qq.com>
Date: Thu, 1 Oct 2020 13:31:34 +0800
Subject: [PATCH] fix XLDST evaluation (#141)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification
---
 convlab2/dst/dstc9/eval_file.py |  3 +--
 convlab2/dst/dstc9/utils.py     | 41 +++++++++++++++++++--------------
 2 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/convlab2/dst/dstc9/eval_file.py b/convlab2/dst/dstc9/eval_file.py
index a89c2e3..4b3e1b2 100644
--- a/convlab2/dst/dstc9/eval_file.py
+++ b/convlab2/dst/dstc9/eval_file.py
@@ -2,9 +2,8 @@
     evaluate output file
 """
 
-import os
 import json
-from copy import deepcopy
+import os
 
 from convlab2.dst.dstc9.utils import prepare_data, extract_gt, eval_states, get_subdir
 
diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index b378fd2..2938f36 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -1,6 +1,7 @@
 import os
 import json
 import zipfile
+from copy import deepcopy
 
 from convlab2 import DATA_ROOT
 
@@ -41,12 +42,10 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
                 sys_utt = turns[i - 1]['content'] if i else None
                 user_utt = turns[i]['content']
                 state = {}
-                for domain_name, domain in turns[i + 1]['sys_state_init'].items():
-                    domain_state = {}
-                    for slot_name, value in domain.items():
-                        if slot_name == 'selectedResults':
-                            continue
-                        domain_state[slot_name] = value
+                for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
+                    selected_results = domain_state.pop('selectedResults')
+                    if selected_results and 'name' in domain_state and not domain_state['name']:
+                        domain_state['name'] = selected_results
                     state[domain_name] = domain_state
                 dialog_data.append((sys_utt, user_utt, state))
             data[dialog_id] = dialog_data
@@ -62,20 +61,28 @@ def extract_gt(test_data):
     return gt
 
 
-def eval_states(gt, pred, subtask):
-    # for unifying values with the same meaning to the same expression
-    value_unifier = {
-        'multiwoz': {
+# Map values that share the same meaning to one canonical expression.
+def unify_value(value, subtask):
+    if isinstance(value, list):
+        ret = deepcopy(value)
+        for i, v in enumerate(ret):
+            ret[i] = unify_value(v, subtask)
+        return ret
 
+    return {
+        'multiwoz': {
+            '未提及': '',
+            'none': '',
+            '是的': '有',
+            '不是': '没有',
         },
         'crosswoz': {
-            '未提及': '',
+            'None': '',
         }
-    }[subtask]
+    }[subtask].get(value, value)
 
-    def unify_value(value):
-        return value_unifier.get(value, value)
 
+def eval_states(gt, pred, subtask):
     def exception(description, **kargs):
         ret = {
             'status': 'exception',
@@ -107,10 +114,10 @@ def eval_states(gt, pred, subtask):
                 for slot_name, gt_value in gt_domain.items():
                     if slot_name not in pred_domain:
                         return exception('slot missing', dialog_id=dialog_id, turn_id=turn_id, domain=domain_name, slot=slot_name)
-                    gt_value = unify_value(gt_value)
-                    pred_value = unify_value(pred_domain[slot_name])
+                    gt_value = unify_value(gt_value, subtask)
+                    pred_value = unify_value(pred_domain[slot_name], subtask)
                     slot_tot += 1
-                    if gt_value == pred_value:
+                    if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
                         slot_acc += 1
                         if gt_value:
                             tp += 1
-- 
GitLab