Skip to content
Snippets Groups Projects
Unverified commit 3812629e, authored by aaa123git and committed by GitHub
Browse files

update get_reward (#196)

* update get_reward

* update doc

* update doc
parent 2236ae82
Branches
No related tags found
No related merge requests found
...@@ -33,12 +33,7 @@ class Environment(): ...@@ -33,12 +33,7 @@ class Environment():
state = self.sys_dst.update(dialog_act) state = self.sys_dst.update(dialog_act)
if self.evaluator: if self.evaluator:
if self.evaluator.task_success(): reward = self.evaluator.get_reward()
reward = 40
elif self.evaluator.cur_domain and self.evaluator.domain_success(self.evaluator.cur_domain):
reward = 5
else:
reward = -1
else: else:
reward = self.usr.get_reward() reward = self.usr.get_reward()
terminated = self.usr.is_terminated() terminated = self.usr.is_terminated()
......
...@@ -126,7 +126,7 @@ class BiSession(Session): ...@@ -126,7 +126,7 @@ class BiSession(Session):
# print('inform prec. {} rec. {} F1 {}'.format(prec, rec, f1)) # print('inform prec. {} rec. {} F1 {}'.format(prec, rec, f1))
# print('book rate {}'.format(self.evaluator.book_rate())) # print('book rate {}'.format(self.evaluator.book_rate()))
# print('task success {}'.format(self.evaluator.task_success())) # print('task success {}'.format(self.evaluator.task_success()))
reward = self.user_agent.get_reward() reward = self.user_agent.get_reward() if self.evaluator is None else self.evaluator.get_reward()
sys_response = self.next_response(user_response) sys_response = self.next_response(user_response)
self.dialog_history.append([self.user_agent.name, user_response]) self.dialog_history.append([self.user_agent.name, user_response])
self.dialog_history.append([self.sys_agent.name, sys_response]) self.dialog_history.append([self.sys_agent.name, sys_response])
......
...@@ -55,3 +55,7 @@ class Evaluator(object): ...@@ -55,3 +55,7 @@ class Evaluator(object):
def final_goal_analyze(self): def final_goal_analyze(self):
"""judge whether the final goal satisfies the database constraints""" """judge whether the final goal satisfies the database constraints"""
raise NotImplementedError raise NotImplementedError
def get_reward(self):
    """Return a reward, which is used for RL training.

    This base implementation always raises; concrete evaluators are
    expected to override it.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
...@@ -413,3 +413,12 @@ class MultiWozEvaluator(Evaluator): ...@@ -413,3 +413,12 @@ class MultiWozEvaluator(Evaluator):
return 1 return 1
else: else:
return match / (match + mismatch) return match / (match + mismatch)
def get_reward(self):
    """Return the reward used for RL training.

    Returns:
        40 if ``self.task_success()`` is truthy; otherwise 5 if
        ``self.cur_domain`` is truthy and
        ``self.domain_success(self.cur_domain)`` is truthy; otherwise -1.
    """
    if self.task_success():
        return 40
    # Short-circuit: domain success is only queried when a current domain
    # is set, exactly as in the original ``elif`` condition.
    if self.cur_domain and self.domain_success(self.cur_domain):
        return 5
    return -1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment