diff --git a/README.md b/README.md index 078ba8458250226e5006428f05d6c0f55d226928..1ffdb5c14664c1519ed761194062a032f79b48f1 100755 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ Performance (the first row is the default config for each module. Empty entries | BERTNLU | RuleDST | RulePolicy | **SCLSTM** | 48.5 | 40.2 | 56.9 | 62.3/62.5/58.7 | 11.9/27.1 | | BERTNLU | RuleDST | **MLEPolicy** | TemplateNLG | 42.7 | 35.9 | 17.6 | 62.8/69.8/62.9 | 12.1/24.1 | | BERTNLU | RuleDST | **PGPolicy** | TemplateNLG | 37.4 | 31.7 | 17.4 | 57.4/63.7/56.9 | 11.0/25.3 | -| BERTNLU | RuleDST | **PPOPolicy** | TemplateNLG | 61.1 | 44.0 | 44.6 | 63.9/76.8/67.2 | 12.5/20.8 | +| BERTNLU | RuleDST | **PPOPolicy** | TemplateNLG | 75.5 | 71.7 | 86.6 | 69.4/85.8/74.1 | 13.1/17.8 | | BERTNLU | RuleDST | **GDPLPolicy** | TemplateNLG | 49.4 | 38.4 | 20.1 | 64.5/73.8/65.6 | 11.5/21.3 | | None | **TRADE** | RulePolicy | TemplateNLG | 32.4 | 20.1 | 34.7 | 46.9/48.5/44.0 | 11.4/23.9 | | None | **SUMBT** | RulePolicy | TemplateNLG | 34.5 | 29.4 | 62.4 | 54.1/50.3/48.3 | 11.0/28.1 | @@ -158,7 +158,7 @@ By running `convlab2/policy/evalutate.py --model_name $model` | --------- | ----------------- | | MLE | 0.56 | | PG | 0.54 | -| PPO | 0.74 | +| PPO | 0.89 | | GDPL | 0.58 | ### NLG diff --git a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py index 97a8e33d08ff5cd26d8c9664954edbbc53e5c5e7..7d04bdf43f05e240faa5ef8ebaf53cc5313e5768 100755 --- a/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -776,11 +776,10 @@ class Agenda(object): self.cur_domain = domain def _setdefault_current_domain_by_usraction(self, usr_action): - if self.cur_domain is None: - for diaact in usr_action.keys(): - domain, _ = diaact.split('-') - if domain in ['attraction', 'hotel', 'restaurant', 'taxi', 'train']: - self.cur_domain = domain + for diaact in usr_action.keys(): + domain, _ = diaact.split('-') + if domain in ['attraction', 'hotel', 'restaurant', 'taxi', 'train']: + self.cur_domain = domain def _remove_item(self, diaact, slot=DEF_VAL_UNK): for idx in range(len(self.__stack)):