diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py
index c1f15dfa031c4c72091cf9418e008c25bd04d804..6216eaaac9fb81615f903f048d6d85766ce663c5 100755
--- a/convlab/dialog_agent/env.py
+++ b/convlab/dialog_agent/env.py
@@ -28,6 +28,8 @@ class Environment():
         return self.sys_dst.state
 
     def step(self, action):
+        # save last system action
+        self.sys_dst.state['system_action'] = action
         if not self.use_semantic_acts:
             model_response = self.sys_nlg.generate(
                 action) if self.sys_nlg else action
@@ -49,7 +51,6 @@ class Environment():
         self.sys_dst.state['user_action'] = dialog_act
         state = self.sys_dst.update(dialog_act)
         state = deepcopy(state)
-        dialog_act = self.sys_dst.state['user_action']
 
         state['history'].append(["sys", model_response])
         state['history'].append(["usr", observation])
diff --git a/convlab/policy/gdpl/gdpl.py b/convlab/policy/gdpl/gdpl.py
index b25f2c6847607577a2620eb609d686724227def0..eacf80efa209b9aa43dc78b9cb5b8bf697afb618 100755
--- a/convlab/policy/gdpl/gdpl.py
+++ b/convlab/policy/gdpl/gdpl.py
@@ -76,7 +76,6 @@ class GDPL(Policy):
         # print('True :')
         # print(a)
         action = self.vector.action_devectorize(a.detach().numpy())
-        state['system_action'] = action
         self.info_dict["action_used"] = action
         # for key in state.keys():
         #     print("Key : {} , Value : {}".format(key,state[key]))
diff --git a/convlab/policy/pg/pg.py b/convlab/policy/pg/pg.py
index 249bc382df65d5fc13b761f55f0c6b19fa23d304..dd47f44f0e733cf43fee69db135dbfef5f0ff614 100755
--- a/convlab/policy/pg/pg.py
+++ b/convlab/policy/pg/pg.py
@@ -74,7 +74,6 @@ class PG(Policy):
         # print('True :')
         # print(a)
         action = self.vector.action_devectorize(a.detach().numpy())
-        state['system_action'] = action
         self.info_dict["action_used"] = action
         # for key in state.keys():
         #     print("Key : {} , Value : {}".format(key,state[key]))
diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py
index 630ec3e67dceddec15217d3fb66489db19543429..7ceefcd77a4f07a1f42c2d4a368715969d3cdf53 100755
--- a/convlab/policy/ppo/ppo.py
+++ b/convlab/policy/ppo/ppo.py
@@ -85,7 +85,6 @@ class PPO(Policy):
         # print('True :')
         # print(a)
         action = self.vector.action_devectorize(a.detach().numpy())
-        state['system_action'] = action
         self.info_dict["action_used"] = action
         # for key in state.keys():
         #     print("Key : {} , Value : {}".format(key,state[key]))
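
Net effect of the patch: the `state['system_action'] = action` bookkeeping moves out of the
`predict` methods of GDPL, PG, and PPO and into the single `Environment.step`. This appears
to matter because `step` hands each policy a `deepcopy` of the tracker state, so a
policy-side write could never reach `self.sys_dst.state`; writing on the tracker itself,
before NLG and the user turn, makes the last executed system action visible to the DST
regardless of which policy is plugged in. (The dropped `dialog_act = ...` line in env.py
was a dead reassignment, since `dialog_act` is not used afterwards.) A minimal sketch of
the copy-vs-tracker distinction, using hypothetical `ToyDST`/`ToyEnv` stand-ins rather
than the real convlab classes:

    from copy import deepcopy

    class ToyDST:
        """Stand-in for the system dialog state tracker."""
        def __init__(self):
            self.state = {'system_action': [], 'user_action': [], 'history': []}

    class ToyEnv:
        """Stand-in for convlab's Environment."""
        def __init__(self):
            self.sys_dst = ToyDST()

        def step(self, action):
            # after this patch: record the action on the tracker itself
            self.sys_dst.state['system_action'] = action
            # policies only ever see a deep copy of the tracker state
            return deepcopy(self.sys_dst.state)

    env = ToyEnv()
    state = env.step([['Inform', 'Hotel', 'Area', 'east']])
    state['system_action'] = []  # the old policy-side write mutated only the copy ...
    # ... so the tracker still holds the action, as the patched env.py guarantees
    assert env.sys_dst.state['system_action'] == [['Inform', 'Hotel', 'Area', 'east']]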