From f822c20f5b5d703019c4bd0048b9e84d96038887 Mon Sep 17 00:00:00 2001
From: aaa123git <wandz19@mails.tsinghua.edu.cn>
Date: Fri, 30 Jul 2021 11:28:54 +0800
Subject: [PATCH] fix Issue 207 and update evaluation results of PPO (#211)

* fix issue 207

* update evaluation results of PPO
---
 README.md                                               | 4 ++--
 convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py | 9 ++++-----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 078ba845..1ffdb5c1 100755
--- a/README.md
+++ b/README.md
@@ -115,7 +115,7 @@ Performance (the first row is the default config for each module. Empty entries
 | BERTNLU | RuleDST | RulePolicy | **SCLSTM**  |   48.5    | 40.2 | 56.9   | 62.3/62.5/58.7 |  11.9/27.1         |
 | BERTNLU     | RuleDST | **MLEPolicy**  | TemplateNLG |     42.7          |    35.9      |  17.6   | 62.8/69.8/62.9  |  12.1/24.1    |
 | BERTNLU | RuleDST | **PGPolicy**   | TemplateNLG |     37.4         |    31.7     |   17.4  |  57.4/63.7/56.9  |   11.0/25.3    |
-| BERTNLU | RuleDST | **PPOPolicy**  | TemplateNLG |     61.1         |    44.0    |   44.6    | 63.9/76.8/67.2  |  12.5/20.8   |
+| BERTNLU | RuleDST | **PPOPolicy**  | TemplateNLG |     75.5         |    71.7    |   86.6    | 69.4/85.8/74.1  |  13.1/17.8   |
 | BERTNLU | RuleDST | **GDPLPolicy** | TemplateNLG |     49.4         |     38.4    |  20.1     |  64.5/73.8/65.6 |  11.5/21.3    |
 | None        | **TRADE** | RulePolicy | TemplateNLG |    32.4      |    20.1     |    34.7      |  46.9/48.5/44.0 |  11.4/23.9      |
 | None        | **SUMBT** | RulePolicy | TemplateNLG |   34.5       |   29.4     |   62.4    |  54.1/50.3/48.3  |   11.0/28.1     |
@@ -158,7 +158,7 @@ By running `convlab2/policy/evalutate.py --model_name $model`
 | --------- | ----------------- |
 | MLE       | 0.56              |
 | PG        | 0.54              |
-| PPO       | 0.74              |
+| PPO       | 0.89              |
 | GDPL      | 0.58              |
 
 ### NLG
diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
index 97a8e33d..7d04bdf4 100755
--- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -776,11 +776,10 @@ class Agenda(object):
                 self.cur_domain = domain
 
     def _setdefault_current_domain_by_usraction(self, usr_action):
-        if self.cur_domain is None:
-            for diaact in usr_action.keys():
-                domain, _ = diaact.split('-')
-                if domain in ['attraction', 'hotel', 'restaurant', 'taxi', 'train']:
-                    self.cur_domain = domain
+        for diaact in usr_action.keys():
+            domain, _ = diaact.split('-')
+            if domain in ['attraction', 'hotel', 'restaurant', 'taxi', 'train']:
+                self.cur_domain = domain
 
     def _remove_item(self, diaact, slot=DEF_VAL_UNK):
         for idx in range(len(self.__stack)):
-- 
GitLab