From 62a7d11bd0b3a86bc145099c5c7e4098a0b62a35 Mon Sep 17 00:00:00 2001
From: function2 <function2@qq.com>
Date: Wed, 28 Oct 2020 22:15:18 +0800
Subject: [PATCH] update eval

---
 convlab2/dst/dstc9/utils.py | 40 ++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/convlab2/dst/dstc9/utils.py b/convlab2/dst/dstc9/utils.py
index 61d0a5f..06615fb 100644
--- a/convlab2/dst/dstc9/utils.py
+++ b/convlab2/dst/dstc9/utils.py
@@ -1,7 +1,6 @@
-import os
 import json
+import os
 import zipfile
-from copy import deepcopy
 
 from convlab2 import DATA_ROOT
 
@@ -23,40 +22,41 @@ def prepare_data(subtask, split, data_root=DATA_ROOT):
             for i in range(0, len(turns), 2):
                 sys_utt = turns[i - 1]['text'] if i else None
                 user_utt = turns[i]['text']
-                state = {}
+                dialog_state = {}
                 for domain_name, domain in turns[i + 1]['metadata'].items():
                     if domain_name in ['警察机关', '医院', '公共汽车']:
                         continue
-                    domain_state = {}
+                    state = {}
                     for slots in domain.values():
                         for slot_name, value in slots.items():
-                            domain_state[slot_name] = value
-                    state[domain_name] = domain_state
-                dialog_data.append((sys_utt, user_utt, state))
+                            state[slot_name] = value
+                    dialog_state[domain_name] = state
+                dialog_data.append((sys_utt, user_utt, dialog_state))
             data[dialog_id] = dialog_data
     else:
         for dialog_id, dialog in test_data.items():
             dialog_data = []
             turns = dialog['messages']
-            selected_results = {k: [] for k in turns[1]['sys_state_init'].keys()}
+            selected_results = {k: [] for k in turns[1]['sys_state'].keys()}
             for i in range(0, len(turns), 2):
                 sys_utt = turns[i - 1]['content'] if i else None
                 user_utt = turns[i]['content']
-                state = {}
-                for domain_name, domain_state in turns[i + 1]['sys_state_init'].items():
-                    new_selected_results = domain_state.pop('selectedResults')
-                    # if state has changed compared to previous turn
-                    state_change = i == 0 or domain_state != dialog_data[-1][2][domain_name]
-                    # clear the invalid previous selected results if state has changed
+                dialog_state = {}
+                for domain_name, state in turns[i + 1]['sys_state_init'].items():
+                    state.pop('selectedResults')
+                    sys_selected_results = turns[i + 1]['sys_state'][domain_name].pop('selectedResults')
+                    # if state has changed compared to previous sys state
+                    state_change = i == 0 or state != turns[i - 1]['sys_state'][domain_name]
+                    # clear the outdated previous selected results if state has been updated
                     if state_change:
                         selected_results[domain_name].clear()
-                    if not domain_state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
-                        domain_state['name'] = selected_results[domain_name][0]
-                    state[domain_name] = domain_state
-                    if state_change:
-                        selected_results[domain_name] = new_selected_results
+                    if not state.get('name', 'something nonempty') and len(selected_results[domain_name]) == 1:
+                        state['name'] = selected_results[domain_name][0]
+                    dialog_state[domain_name] = state
+                    if state_change and sys_selected_results:
+                        selected_results[domain_name] = sys_selected_results
 
-                dialog_data.append((sys_utt, user_utt, state))
+                dialog_data.append((sys_utt, user_utt, dialog_state))
             data[dialog_id] = dialog_data
 
     return data
-- 
GitLab