Skip to content
Snippets Groups Projects
Commit 64e8a6a8 authored by Christian's avatar Christian
Browse files

the last system act is now updated in environment to be consistent with...

the last system act is now updated in environment to be consistent with pipeline agent and not let the policy change it, which is more error-prone
parent b5b66cfa
Branches change_system_act_in_env
No related tags found
Loading
...@@ -28,6 +28,8 @@ class Environment(): ...@@ -28,6 +28,8 @@ class Environment():
return self.sys_dst.state return self.sys_dst.state
def step(self, action): def step(self, action):
# save last system action
self.sys_dst.state['system_action'] = action
if not self.use_semantic_acts: if not self.use_semantic_acts:
model_response = self.sys_nlg.generate( model_response = self.sys_nlg.generate(
action) if self.sys_nlg else action action) if self.sys_nlg else action
...@@ -49,7 +51,6 @@ class Environment(): ...@@ -49,7 +51,6 @@ class Environment():
self.sys_dst.state['user_action'] = dialog_act self.sys_dst.state['user_action'] = dialog_act
state = self.sys_dst.update(dialog_act) state = self.sys_dst.update(dialog_act)
state = deepcopy(state) state = deepcopy(state)
dialog_act = self.sys_dst.state['user_action']
state['history'].append(["sys", model_response]) state['history'].append(["sys", model_response])
state['history'].append(["usr", observation]) state['history'].append(["usr", observation])
......
...@@ -76,7 +76,6 @@ class GDPL(Policy): ...@@ -76,7 +76,6 @@ class GDPL(Policy):
# print('True :') # print('True :')
# print(a) # print(a)
action = self.vector.action_devectorize(a.detach().numpy()) action = self.vector.action_devectorize(a.detach().numpy())
state['system_action'] = action
self.info_dict["action_used"] = action self.info_dict["action_used"] = action
# for key in state.keys(): # for key in state.keys():
# print("Key : {} , Value : {}".format(key,state[key])) # print("Key : {} , Value : {}".format(key,state[key]))
......
...@@ -74,7 +74,6 @@ class PG(Policy): ...@@ -74,7 +74,6 @@ class PG(Policy):
# print('True :') # print('True :')
# print(a) # print(a)
action = self.vector.action_devectorize(a.detach().numpy()) action = self.vector.action_devectorize(a.detach().numpy())
state['system_action'] = action
self.info_dict["action_used"] = action self.info_dict["action_used"] = action
# for key in state.keys(): # for key in state.keys():
# print("Key : {} , Value : {}".format(key,state[key])) # print("Key : {} , Value : {}".format(key,state[key]))
......
...@@ -85,7 +85,6 @@ class PPO(Policy): ...@@ -85,7 +85,6 @@ class PPO(Policy):
# print('True :') # print('True :')
# print(a) # print(a)
action = self.vector.action_devectorize(a.detach().numpy()) action = self.vector.action_devectorize(a.detach().numpy())
state['system_action'] = action
self.info_dict["action_used"] = action self.info_dict["action_used"] = action
# for key in state.keys(): # for key in state.keys():
# print("Key : {} , Value : {}".format(key,state[key])) # print("Key : {} , Value : {}".format(key,state[key]))
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment