From 806356d0bb1f7aaad77cc676d2bfa4b1072ccb6d Mon Sep 17 00:00:00 2001 From: newRuntieException <wdz15@mails.tsinghua.edu.cn> Date: Tue, 25 Aug 2020 13:16:04 +0800 Subject: [PATCH] output goal analysis to file --- convlab2/evaluator/evaluator.py | 4 ++++ convlab2/evaluator/multiwoz_eval.py | 2 ++ convlab2/policy/evaluate.py | 1 + convlab2/util/analysis_tool/analyzer.py | 13 +++++++++++++ 4 files changed, 20 insertions(+) diff --git a/convlab2/evaluator/evaluator.py b/convlab2/evaluator/evaluator.py index 9dfaee0..89a37b3 100755 --- a/convlab2/evaluator/evaluator.py +++ b/convlab2/evaluator/evaluator.py @@ -51,3 +51,7 @@ class Evaluator(object): judge if the domain (subtask) is successfully completed """ raise NotImplementedError + + def final_goal_analyze(self): + """judge whether the final goal satisfies the database constraints""" + raise NotImplementedError diff --git a/convlab2/evaluator/multiwoz_eval.py b/convlab2/evaluator/multiwoz_eval.py index fd84621..78edc42 100755 --- a/convlab2/evaluator/multiwoz_eval.py +++ b/convlab2/evaluator/multiwoz_eval.py @@ -396,6 +396,8 @@ class MultiWozEvaluator(Evaluator): return match, mismatch def final_goal_analyze(self): + """percentage of domains, in which the final goal satisfies the database constraints. + If there is no dialog action, returns 1.""" match, mismatch = self._final_goal_analyze() if match == mismatch == 0: return 1 diff --git a/convlab2/policy/evaluate.py b/convlab2/policy/evaluate.py index 87c7b02..8085bf4 100755 --- a/convlab2/policy/evaluate.py +++ b/convlab2/policy/evaluate.py @@ -213,6 +213,7 @@ def evaluate(dataset_name, model_name, load_path, calculate_reward=True): logging.info(f'task success: {task_succ}') logging.info(f'book rate: {sess.evaluator.book_rate()}') logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}') + logging.info(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}") logging.info('-'*50) break else: diff --git a/convlab2/util/analysis_tool/analyzer.py b/convlab2/util/analysis_tool/analyzer.py index eeae31c..8cd3ebc 100755 --- a/convlab2/util/analysis_tool/analyzer.py +++ b/convlab2/util/analysis_tool/analyzer.py @@ -50,6 +50,7 @@ class Analyzer: print('task success:', sess.evaluator.task_success()) print('book rate:', sess.evaluator.book_rate()) print('inform precision/recall/f1:', sess.evaluator.inform_F1()) + print(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}") print('-' * 50) print('final goal:') pprint(sess.evaluator.goal) @@ -67,6 +68,8 @@ class Analyzer: complete_num = 0 turn_num = 0 turn_suc_num = 0 + num_domains = 0 + num_domains_satisfying_constraints = 0 reporter = Reporter(model_name) logger = logging.getLogger(__name__) @@ -142,6 +145,7 @@ class Analyzer: task_complete = sess.user_agent.policy.policy.goal.task_complete() book_rate = sess.evaluator.book_rate() stats = sess.evaluator.inform_F1() + percentage = sess.evaluator.final_goal_analyze() if task_success: suc_num += 1 turn_suc_num += step @@ -153,6 +157,9 @@ class Analyzer: f1.append(stats[2]) if book_rate is not None: match.append(book_rate) + if len(sess.evaluator.goal) > 0: + num_domains += len(sess.evaluator.goal) + num_domains_satisfying_constraints += len(sess.evaluator.goal) * percentage if (j+1) % 100 == 0: logger.info("model name %s", model_name) logger.info("dialogue %d", j+1) @@ -161,6 +168,8 @@ class Analyzer: logger.info('task success: %.3f', suc_num/(j+1)) logger.info('book rate: %.3f', np.mean(match)) logger.info('inform precision/recall/f1: %.3f %.3f %.3f', np.mean(precision), np.mean(recall), np.mean(f1)) + logging.info("percentage of domains that satisfies the database constraints: %.3f}" % \ + (1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains))) domain_set = [] for da in sess.evaluator.usr_da_array: if da.split('-')[0] != 'general' and da.split('-')[0] not in domain_set: @@ -195,6 +204,8 @@ class Analyzer: print('average book rate:', np.mean(match)) print("average turn (succ):", tmp) print("average turn (all):", turn_num / total_dialog) + print("percentage of domains that satisfies the database constraints: %.3f}" % \ + (1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains))) print("=" * 100) print("complete number of dialogs/tot:", complete_num / total_dialog, file=f) print("success number of dialogs/tot:", suc_num / total_dialog, file=f) @@ -204,6 +215,8 @@ class Analyzer: print('average book rate:', np.mean(match), file=f) print("average turn (succ):", tmp, file=f) print("average turn (all):", turn_num / total_dialog, file=f) + print("percentage of domains that satisfies the database constraints: %.3f}" % \ + (1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains)), file=f) f.close() reporter.report(complete_num/total_dialog, suc_num/total_dialog, np.mean(precision), np.mean(recall), np.mean(f1), tmp, turn_num / total_dialog) -- GitLab