From 67bf7897ab9d71dd97c08a7bb666bad5beb44230 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BD=97=E5=B4=9A=E9=AA=81=28Lingxiao=20Luo=29?=
 <function2@qq.com>
Date: Fri, 2 Apr 2021 15:21:48 +0800
Subject: [PATCH] Update legacy XLDST evaluation (#186)

* remove transformer cache dir

* process data

* fix "book" slots processing in MultiWOZ-zh SUMBT, update evaluation results #185

* Revert "process data"

This reverts commit d17602c23cccb482827d8892554f32eb69dde297.

* Revert "remove transformer cache dir"

This reverts commit 35873129eb8d45a5bebada63b4549de88b665873.
---
 README.md                               |  6 +++---
 convlab2/dst/sumbt/multiwoz_zh/sumbt.py | 14 +++-----------
 2 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index ffd1e93..01408f9 100755
--- a/README.md
+++ b/README.md
@@ -204,9 +204,9 @@ evaluation of our pre-trained models are: (joint acc.)
 
 | type  | CrossWOZ-en | MultiWOZ-zh |
 | ----- | ----------- | ----------- |
-| val   | 12.4%       | 45.1%       |
-| test  | 12.4%       | 43.5%       |
-| human_val | 10.6%       | 49.4%       |
+| val   | 12.4%       | 48.5%       |
+| test  | 12.4%       | 46.0%       |
+| human_val | 10.6%       | 47.4%       |
 
 `human_val` option will make the model evaluate on the validation set translated by human. 
 
diff --git a/convlab2/dst/sumbt/multiwoz_zh/sumbt.py b/convlab2/dst/sumbt/multiwoz_zh/sumbt.py
index 96e83f0..047519f 100644
--- a/convlab2/dst/sumbt/multiwoz_zh/sumbt.py
+++ b/convlab2/dst/sumbt/multiwoz_zh/sumbt.py
@@ -582,28 +582,20 @@ class SUMBTTracker(DST):
         new_belief_state = copy.deepcopy(prev_state['belief_state'])
         for state in pred_states:
             domain, slot, value = state.split('-', 2)
-            
-            if slot not in ['name', 'book']:
-                if domain not in new_belief_state:
-                    if domain == 'bus':
-                        continue
-                    else:
-                        raise Exception(
-                            'Error: domain <{}> not in belief state'.format(domain))
             # slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
             assert 'semi' in new_belief_state[domain]
             assert 'book' in new_belief_state[domain]
+            domain_dic = new_belief_state[domain]
             if '预订' in slot:
                 assert slot.startswith('预订')
+                slot = slot[2:]
+                assert slot in domain_dic['book']
 
-            domain_dic = new_belief_state[domain]
             if slot in domain_dic['semi']:
                 new_belief_state[domain]['semi'][slot] = value
                 # normalize_value(self.value_dict, domain, slot, value)
             elif slot in domain_dic['book']:
                 new_belief_state[domain]['book'][slot] = value
-            elif slot.lower() in domain_dic['book']:
-                new_belief_state[domain]['book'][slot.lower()] = value
             else:
                 with open('trade_tracker_unknown_slot.log', 'a+') as f:
                     f.write(
-- 
GitLab