From 3812629e8dde63edd4ca5ba94e93d919d50a5d66 Mon Sep 17 00:00:00 2001 From: aaa123git <wandz19@mails.tsinghua.edu.cn> Date: Thu, 20 May 2021 18:44:47 +0800 Subject: [PATCH] update get_reward (#196) * update get_reward * update doc * update doc --- convlab2/dialog_agent/env.py | 7 +------ convlab2/dialog_agent/session.py | 2 +- convlab2/evaluator/evaluator.py | 4 ++++ convlab2/evaluator/multiwoz_eval.py | 9 +++++++++ 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/convlab2/dialog_agent/env.py b/convlab2/dialog_agent/env.py index e26f489..8cb5616 100755 --- a/convlab2/dialog_agent/env.py +++ b/convlab2/dialog_agent/env.py @@ -33,12 +33,7 @@ class Environment(): state = self.sys_dst.update(dialog_act) if self.evaluator: - if self.evaluator.task_success(): - reward = 40 - elif self.evaluator.cur_domain and self.evaluator.domain_success(self.evaluator.cur_domain): - reward = 5 - else: - reward = -1 + reward = self.evaluator.get_reward() else: reward = self.usr.get_reward() terminated = self.usr.is_terminated() diff --git a/convlab2/dialog_agent/session.py b/convlab2/dialog_agent/session.py index ffe70de..64fe8df 100755 --- a/convlab2/dialog_agent/session.py +++ b/convlab2/dialog_agent/session.py @@ -126,7 +126,7 @@ class BiSession(Session): # print('inform prec. {} rec. {} F1 {}'.format(prec, rec, f1)) # print('book rate {}'.format(self.evaluator.book_rate())) # print('task success {}'.format(self.evaluator.task_success())) - reward = self.user_agent.get_reward() + reward = self.user_agent.get_reward() if self.evaluator is None else self.evaluator.get_reward() sys_response = self.next_response(user_response) self.dialog_history.append([self.user_agent.name, user_response]) self.dialog_history.append([self.sys_agent.name, sys_response]) diff --git a/convlab2/evaluator/evaluator.py b/convlab2/evaluator/evaluator.py index 89a37b3..7857023 100755 --- a/convlab2/evaluator/evaluator.py +++ b/convlab2/evaluator/evaluator.py @@ -55,3 +55,7 @@ class Evaluator(object): def final_goal_analyze(self): """judge whether the final goal satisfies the database constraints""" raise NotImplementedError + + def get_reward(self): + """returns a reward, which is used for RL training.""" + raise NotImplementedError diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py index e7303c5..11dd0dd 100755 --- a/convlab2/evaluator/multiwoz_eval.py +++ b/convlab2/evaluator/multiwoz_eval.py @@ -413,3 +413,12 @@ class MultiWozEvaluator(Evaluator): return 1 else: return match / (match + mismatch) + + def get_reward(self): + if self.task_success(): + reward = 40 + elif self.cur_domain and self.domain_success(self.cur_domain): + reward = 5 + else: + reward = -1 + return reward -- GitLab