From 64e8a6a8a3d83ddec5a25a909f0499326f1e46cb Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Tue, 26 Jul 2022 16:18:33 +0200
Subject: [PATCH] the last system act is now updated in environment to be
 consistent with pipeline agent and not let the policy change it, which is
 more errorprone

---
 convlab/dialog_agent/env.py | 3 ++-
 convlab/policy/gdpl/gdpl.py | 1 -
 convlab/policy/pg/pg.py     | 1 -
 convlab/policy/ppo/ppo.py   | 1 -
 4 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py
index c1f15dfa..6216eaaa 100755
--- a/convlab/dialog_agent/env.py
+++ b/convlab/dialog_agent/env.py
@@ -28,6 +28,8 @@ class Environment():
         return self.sys_dst.state
 
     def step(self, action):
+        # save last system action
+        self.sys_dst.state['system_action'] = action
         if not self.use_semantic_acts:
             model_response = self.sys_nlg.generate(
                 action) if self.sys_nlg else action
@@ -49,7 +51,6 @@ class Environment():
         self.sys_dst.state['user_action'] = dialog_act
         state = self.sys_dst.update(dialog_act)
         state = deepcopy(state)
-        dialog_act = self.sys_dst.state['user_action']
 
         state['history'].append(["sys", model_response])
         state['history'].append(["usr", observation])
diff --git a/convlab/policy/gdpl/gdpl.py b/convlab/policy/gdpl/gdpl.py
index b25f2c68..eacf80ef 100755
--- a/convlab/policy/gdpl/gdpl.py
+++ b/convlab/policy/gdpl/gdpl.py
@@ -76,7 +76,6 @@ class GDPL(Policy):
         # print('True :')
         # print(a)
         action = self.vector.action_devectorize(a.detach().numpy())
-        state['system_action'] = action
         self.info_dict["action_used"] = action
         # for key in state.keys():
         #     print("Key : {} , Value : {}".format(key,state[key]))
diff --git a/convlab/policy/pg/pg.py b/convlab/policy/pg/pg.py
index 249bc382..dd47f44f 100755
--- a/convlab/policy/pg/pg.py
+++ b/convlab/policy/pg/pg.py
@@ -74,7 +74,6 @@ class PG(Policy):
         # print('True :')
         # print(a)
         action = self.vector.action_devectorize(a.detach().numpy())
-        state['system_action'] = action
         self.info_dict["action_used"] = action
         # for key in state.keys():
         #     print("Key : {} , Value : {}".format(key,state[key]))
diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py
index 630ec3e6..7ceefcd7 100755
--- a/convlab/policy/ppo/ppo.py
+++ b/convlab/policy/ppo/ppo.py
@@ -85,7 +85,6 @@ class PPO(Policy):
         # print('True :')
         # print(a)
         action = self.vector.action_devectorize(a.detach().numpy())
-        state['system_action'] = action
         self.info_dict["action_used"] = action
         # for key in state.keys():
         #     print("Key : {} , Value : {}".format(key,state[key]))
-- 
GitLab