From c761fc7b4e126bb83cc91edd4025027a6da724f4 Mon Sep 17 00:00:00 2001
From: function2 <function2@qq.com>
Date: Thu, 29 Oct 2020 10:56:15 +0800
Subject: [PATCH] Revert "update eval"

This reverts commit 02537cf8f6474a33bb2d35e640e7f9d9b5b86f52.
---
 convlab2/dst/dstc9/utils.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)

diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index 06615fb..7c88888 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -37,26 +37,16 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
         for dialog_id, dialog in test_data.items():
             dialog_data = []
             turns = dialog['messages']
-            selected_results = {k: [] for k in turns[1]['sys_state'].keys()}
             for i in range(0, len(turns), 2):
                 sys_utt = turns[i - 1]['content'] if i else None
                 user_utt = turns[i]['content']
-                dialog_state = {}
-                for domain_name, state in turns[i + 1]['sys_state_init'].items():
-                    state.pop('selectedResults')
-                    sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
-                    # if state has changed compared to previous sys state
-                    state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
-                    # clear the outdated previous selected results if state has been updated
-                    if state_change:
-                        selected_results[domain_name].clear()
-                    if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
-                        state['name'] = selected_results[domain_name][0]
-                    dialog_state[domain_name] = state
-                    if state_change and sys_selected_results:
-                        selected_results[domain_name] = sys_selected_results
-
-                dialog_data.append((sys_utt, user_utt, dialog_state))
+                state = {}
+                for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
+                    selected_results = domain_state.pop('selectedResults')
+                    if selected_results and 'name' in domain_state and not domain_state['name']:
+                        domain_state['name'] = selected_results
+                    state[domain_name] = domain_state
+                dialog_data.append((sys_utt, user_utt, state))
             data[dialog_id] = dialog_data
 
     return data
@@ -72,6 +62,12 @@ def extract_gt(test_data):
 
 # for unifying values with the same meaning to the same expression
 def unify_value(value, subtask):
+    if isinstance(value, list):
+        ret = deepcopy(value)
+        for i, v in enumerate(ret):
+            ret[i] = unify_value(v, subtask)
+        return ret
+
     value = value.lower()
     value = {
         'multiwoz': {
@@ -126,7 +122,7 @@ def eval_states(gt, pred, subtask):
                     pred_value = unify_value(pred_domain[slot_name], subtask)
                     slot_tot += 1
 
-                    if gt_value == pred_value:
+                    if gt_value == pred_value or isinstance(gt_value, list) and pred_value in gt_value:
                         slot_acc += 1
                         if gt_value:
                             tp += 1
-- 
GitLab