Route RL reward computation through the evaluator.

Adds Evaluator.get_reward() (abstract, raises NotImplementedError) and
MultiWozEvaluator.get_reward(), which returns 40 when task_success() holds,
5 when cur_domain is set and domain_success(cur_domain) holds, and -1
otherwise. Environment (when self.evaluator is truthy) and BiSession (when
self.evaluator is not None) now call evaluator.get_reward(); without an
evaluator they keep using the user side's get_reward().

NOTE(review): the line breaks and indentation of this patch were
reconstructed from the hunk line counts and ordinary Python nesting. If a
context line fails to match, apply with `git apply --ignore-whitespace`
(or `patch -l`) and confirm the result against the target files.
---
diff --git a/convlab2/dialog_agent/env.py b/convlab2/dialog_agent/env.py
index e26f4891396a69c2f072b10def2bc86d419aa930..8cb56163241e5378fc639c379fc5dce1b0bb77d7 100755
--- a/convlab2/dialog_agent/env.py
+++ b/convlab2/dialog_agent/env.py
@@ -33,12 +33,7 @@ class Environment():
         state = self.sys_dst.update(dialog_act)
 
         if self.evaluator:
-            if self.evaluator.task_success():
-                reward = 40
-            elif self.evaluator.cur_domain and self.evaluator.domain_success(self.evaluator.cur_domain):
-                reward = 5
-            else:
-                reward = -1
+            reward = self.evaluator.get_reward()
         else:
             reward = self.usr.get_reward()
         terminated = self.usr.is_terminated()
diff --git a/convlab2/dialog_agent/session.py b/convlab2/dialog_agent/session.py
index ffe70de297fcf9c8aa11c47ce1e880fe50db5da9..64fe8dfb8c197a8441d28ac10dd64753cbb7734f 100755
--- a/convlab2/dialog_agent/session.py
+++ b/convlab2/dialog_agent/session.py
@@ -126,7 +126,7 @@ class BiSession(Session):
         # print('inform prec. {} rec. {} F1 {}'.format(prec, rec, f1))
         # print('book rate {}'.format(self.evaluator.book_rate()))
         # print('task success {}'.format(self.evaluator.task_success()))
-        reward = self.user_agent.get_reward()
+        reward = self.user_agent.get_reward() if self.evaluator is None else self.evaluator.get_reward()
         sys_response = self.next_response(user_response)
         self.dialog_history.append([self.user_agent.name, user_response])
         self.dialog_history.append([self.sys_agent.name, sys_response])
diff --git a/convlab2/evaluator/evaluator.py b/convlab2/evaluator/evaluator.py
index 89a37b35bc77973e8ab381b066b4d7a77d4615f4..78570239a0cf926dea0f3913a61d3e57b45d7d43 100755
--- a/convlab2/evaluator/evaluator.py
+++ b/convlab2/evaluator/evaluator.py
@@ -55,3 +55,7 @@ class Evaluator(object):
     def final_goal_analyze(self):
         """judge whether the final goal satisfies the database constraints"""
         raise NotImplementedError
+
+    def get_reward(self):
+        """returns a reward, which is used for RL training."""
+        raise NotImplementedError
diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py
index e7303c56d204eed52bf9deb3d55a5a258f9a1ff1..11dd0ddd2030431abd1577fd2c0b9251575d83c0 100755
--- a/convlab2/evaluator/multiwoz_eval.py
+++ b/convlab2/evaluator/multiwoz_eval.py
@@ -413,3 +413,12 @@ class MultiWozEvaluator(Evaluator):
             return 1
         else:
             return match / (match + mismatch)
+
+    def get_reward(self):
+        if self.task_success():
+            reward = 40
+        elif self.cur_domain and self.domain_success(self.cur_domain):
+            reward = 5
+        else:
+            reward = -1
+        return reward