From 806356d0bb1f7aaad77cc676d2bfa4b1072ccb6d Mon Sep 17 00:00:00 2001
From: newRuntieException <wdz15@mails.tsinghua.edu.cn>
Date: Tue, 25 Aug 2020 13:16:04 +0800
Subject: [PATCH] output goal analysis to file

---
 convlab2/evaluator/evaluator.py         |  4 ++++
 convlab2/evaluator/multiwoz_eval.py     |  2 ++
 convlab2/policy/evaluate.py             |  1 +
 convlab2/util/analysis_tool/analyzer.py | 13 +++++++++++++
 4 files changed, 20 insertions(+)

diff --git a/convlab2/evaluator/evaluator.py b/convlab2/evaluator/evaluator.py
index 9dfaee0..89a37b3 100755
--- a/convlab2/evaluator/evaluator.py
+++ b/convlab2/evaluator/evaluator.py
@@ -51,3 +51,7 @@ class Evaluator(object):
         judge if the domain (subtask) is successfully completed
         """
         raise NotImplementedError
+
+    def final_goal_analyze(self):
+        """Judge whether the final goal satisfies the database constraints; subclasses must override this."""
+        raise NotImplementedError
diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py
index fd84621..78edc42 100755
--- a/convlab2/evaluator/multiwoz_eval.py
+++ b/convlab2/evaluator/multiwoz_eval.py
@@ -396,6 +396,8 @@ class MultiWozEvaluator(Evaluator):
         return match, mismatch
 
     def final_goal_analyze(self):
+        """Return the percentage of domains in which the final goal satisfies the database constraints.
+        If nothing was matched or mismatched (i.e. there is no dialog action), return 1."""
         match, mismatch = self._final_goal_analyze()
         if match == mismatch == 0:
             return 1
diff --git a/convlab2/policy/evaluate.py b/convlab2/policy/evaluate.py
index 87c7b02..8085bf4 100755
--- a/convlab2/policy/evaluate.py
+++ b/convlab2/policy/evaluate.py
@@ -213,6 +213,7 @@ def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
                     logging.info(f'task success: {task_succ}')
                     logging.info(f'book rate: {sess.evaluator.book_rate()}')
                     logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
+                    logging.info(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}")
                     logging.info('-'*50)
                     break
             else: 
diff --git a/convlab2/util/analysis_tool/analyzer.py b/convlab2/util/analysis_tool/analyzer.py
index eeae31c..8cd3ebc 100755
--- a/convlab2/util/analysis_tool/analyzer.py
+++ b/convlab2/util/analysis_tool/analyzer.py
@@ -50,6 +50,7 @@ class Analyzer:
         print('task success:', sess.evaluator.task_success())
         print('book rate:', sess.evaluator.book_rate())
         print('inform precision/recall/f1:', sess.evaluator.inform_F1())
+        print(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}")
         print('-' * 50)
         print('final goal:')
         pprint(sess.evaluator.goal)
@@ -67,6 +68,8 @@ class Analyzer:
         complete_num = 0
         turn_num = 0
         turn_suc_num = 0
+        num_domains = 0
+        num_domains_satisfying_constraints = 0
 
         reporter = Reporter(model_name)
         logger = logging.getLogger(__name__)
@@ -142,6 +145,7 @@ class Analyzer:
             task_complete = sess.user_agent.policy.policy.goal.task_complete()
             book_rate = sess.evaluator.book_rate()
             stats = sess.evaluator.inform_F1()
+            percentage = sess.evaluator.final_goal_analyze()
             if task_success:
                 suc_num += 1
                 turn_suc_num += step
@@ -153,6 +157,9 @@ class Analyzer:
                 f1.append(stats[2])
             if book_rate is not None:
                 match.append(book_rate)
+            if len(sess.evaluator.goal) > 0:
+                num_domains += len(sess.evaluator.goal)
+                num_domains_satisfying_constraints += len(sess.evaluator.goal) * percentage
             if (j+1) % 100 == 0:
                 logger.info("model name %s", model_name)
                 logger.info("dialogue %d", j+1)
@@ -161,6 +168,8 @@ class Analyzer:
                 logger.info('task success: %.3f', suc_num/(j+1))
                 logger.info('book rate: %.3f', np.mean(match))
                 logger.info('inform precision/recall/f1: %.3f %.3f %.3f', np.mean(precision), np.mean(recall), np.mean(f1))
+                logger.info("percentage of domains that satisfies the database constraints: %.3f",
+                            1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains))
             domain_set = []
             for da in sess.evaluator.usr_da_array:
                 if da.split('-')[0] != 'general' and da.split('-')[0] not in domain_set:
@@ -195,6 +204,8 @@ class Analyzer:
         print('average book rate:', np.mean(match))
         print("average turn (succ):", tmp)
         print("average turn (all):", turn_num / total_dialog)
+        print("percentage of domains that satisfies the database constraints: %.3f" %
+              (1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains)))
         print("=" * 100)
         print("complete number of dialogs/tot:", complete_num / total_dialog, file=f)
         print("success number of dialogs/tot:", suc_num / total_dialog, file=f)
@@ -204,6 +215,8 @@ class Analyzer:
         print('average book rate:', np.mean(match), file=f)
         print("average turn (succ):", tmp, file=f)
         print("average turn (all):", turn_num / total_dialog, file=f)
+        print("percentage of domains that satisfies the database constraints: %.3f" %
+              (1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains)), file=f)
         f.close()
 
         reporter.report(complete_num/total_dialog, suc_num/total_dialog, np.mean(precision), np.mean(recall), np.mean(f1), tmp, turn_num / total_dialog)
-- 
GitLab