From 3812629e8dde63edd4ca5ba94e93d919d50a5d66 Mon Sep 17 00:00:00 2001
From: aaa123git <wandz19@mails.tsinghua.edu.cn>
Date: Thu, 20 May 2021 18:44:47 +0800
Subject: [PATCH] update get_reward (#196)

* update get_reward

* update doc

* update doc
---
 convlab2/dialog_agent/env.py        | 7 +------
 convlab2/dialog_agent/session.py    | 2 +-
 convlab2/evaluator/evaluator.py     | 4 ++++
 convlab2/evaluator/multiwoz_eval.py | 9 +++++++++
 4 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/convlab2/dialog_agent/env.py b/convlab2/dialog_agent/env.py
index e26f489..8cb5616 100755
--- a/convlab2/dialog_agent/env.py
+++ b/convlab2/dialog_agent/env.py
@@ -33,12 +33,7 @@ class Environment():
         state = self.sys_dst.update(dialog_act)
         
         if self.evaluator:
-            if self.evaluator.task_success():
-                reward = 40
-            elif self.evaluator.cur_domain and self.evaluator.domain_success(self.evaluator.cur_domain):
-                reward = 5
-            else:
-                reward = -1
+            reward = self.evaluator.get_reward()
         else:
             reward = self.usr.get_reward()
         terminated = self.usr.is_terminated()
diff --git a/convlab2/dialog_agent/session.py b/convlab2/dialog_agent/session.py
index ffe70de..64fe8df 100755
--- a/convlab2/dialog_agent/session.py
+++ b/convlab2/dialog_agent/session.py
@@ -126,7 +126,7 @@ class BiSession(Session):
             # print('inform prec. {} rec. {} F1 {}'.format(prec, rec, f1))
             # print('book rate {}'.format(self.evaluator.book_rate()))
             # print('task success {}'.format(self.evaluator.task_success()))
-        reward = self.user_agent.get_reward()
+        reward = self.user_agent.get_reward() if self.evaluator is None else self.evaluator.get_reward()
         sys_response = self.next_response(user_response)
         self.dialog_history.append([self.user_agent.name, user_response])
         self.dialog_history.append([self.sys_agent.name, sys_response])
diff --git a/convlab2/evaluator/evaluator.py b/convlab2/evaluator/evaluator.py
index 89a37b3..7857023 100755
--- a/convlab2/evaluator/evaluator.py
+++ b/convlab2/evaluator/evaluator.py
@@ -55,3 +55,7 @@ class Evaluator(object):
     def final_goal_analyze(self):
         """judge whether the final goal satisfies the database constraints"""
         raise NotImplementedError
+
+    def get_reward(self):
+        """returns a reward, which is used for RL training."""
+        raise NotImplementedError
diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py
index e7303c5..11dd0dd 100755
--- a/convlab2/evaluator/multiwoz_eval.py
+++ b/convlab2/evaluator/multiwoz_eval.py
@@ -413,3 +413,12 @@ class MultiWozEvaluator(Evaluator):
             return 1
         else:
             return match / (match + mismatch)
+
+    def get_reward(self):
+        if self.task_success():
+            reward = 40
+        elif self.cur_domain and self.domain_success(self.cur_domain):
+            reward = 5
+        else:
+            reward = -1
+        return reward
-- 
GitLab