From 64e8a6a8a3d83ddec5a25a909f0499326f1e46cb Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Tue, 26 Jul 2022 16:18:33 +0200 Subject: [PATCH] the last system act is now updated in environment to be consistent with pipeline agent and not let the policy change it, which is more error-prone --- convlab/dialog_agent/env.py | 3 ++- convlab/policy/gdpl/gdpl.py | 1 - convlab/policy/pg/pg.py | 1 - convlab/policy/ppo/ppo.py | 1 - 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py index c1f15dfa..6216eaaa 100755 --- a/convlab/dialog_agent/env.py +++ b/convlab/dialog_agent/env.py @@ -28,6 +28,8 @@ class Environment(): return self.sys_dst.state def step(self, action): + # save last system action + self.sys_dst.state['system_action'] = action if not self.use_semantic_acts: model_response = self.sys_nlg.generate( action) if self.sys_nlg else action @@ -49,7 +51,6 @@ class Environment(): self.sys_dst.state['user_action'] = dialog_act state = self.sys_dst.update(dialog_act) state = deepcopy(state) - dialog_act = self.sys_dst.state['user_action'] state['history'].append(["sys", model_response]) state['history'].append(["usr", observation]) diff --git a/convlab/policy/gdpl/gdpl.py b/convlab/policy/gdpl/gdpl.py index b25f2c68..eacf80ef 100755 --- a/convlab/policy/gdpl/gdpl.py +++ b/convlab/policy/gdpl/gdpl.py @@ -76,7 +76,6 @@ class GDPL(Policy): # print('True :') # print(a) action = self.vector.action_devectorize(a.detach().numpy()) - state['system_action'] = action self.info_dict["action_used"] = action # for key in state.keys(): # print("Key : {} , Value : {}".format(key,state[key])) diff --git a/convlab/policy/pg/pg.py b/convlab/policy/pg/pg.py index 249bc382..dd47f44f 100755 --- a/convlab/policy/pg/pg.py +++ b/convlab/policy/pg/pg.py @@ -74,7 +74,6 @@ class PG(Policy): # print('True :') # print(a) action = self.vector.action_devectorize(a.detach().numpy()) - state['system_action'] = 
action self.info_dict["action_used"] = action # for key in state.keys(): # print("Key : {} , Value : {}".format(key,state[key])) diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py index 630ec3e6..7ceefcd7 100755 --- a/convlab/policy/ppo/ppo.py +++ b/convlab/policy/ppo/ppo.py @@ -85,7 +85,6 @@ class PPO(Policy): # print('True :') # print(a) action = self.vector.action_devectorize(a.detach().numpy()) - state['system_action'] = action self.info_dict["action_used"] = action # for key in state.keys(): # print("Key : {} , Value : {}".format(key,state[key])) -- GitLab