Skip to content
Snippets Groups Projects
Commit 806356d0 authored by newRuntieException's avatar newRuntieException Committed by zhuqi
Browse files

output goal analysis to file

parent c86fb226
No related branches found
No related tags found
No related merge requests found
...@@ -51,3 +51,7 @@ class Evaluator(object): ...@@ -51,3 +51,7 @@ class Evaluator(object):
judge if the domain (subtask) is successfully completed judge if the domain (subtask) is successfully completed
""" """
raise NotImplementedError raise NotImplementedError
def final_goal_analyze(self):
    """Return the fraction of domains whose final goal satisfies the database constraints.

    Subclasses must override this method. Callers weight the result by the
    number of goal domains, so it is expected to be a number between 0 and 1;
    the MultiWOZ implementation returns 1 when there is nothing to analyze.

    Raises:
        NotImplementedError: always, in this abstract base implementation.
    """
    raise NotImplementedError
...@@ -396,6 +396,8 @@ class MultiWozEvaluator(Evaluator): ...@@ -396,6 +396,8 @@ class MultiWozEvaluator(Evaluator):
return match, mismatch return match, mismatch
def final_goal_analyze(self): def final_goal_analyze(self):
"""percentage of domains, in which the final goal satisfies the database constraints.
If there is no dialog action, returns 1."""
match, mismatch = self._final_goal_analyze() match, mismatch = self._final_goal_analyze()
if match == mismatch == 0: if match == mismatch == 0:
return 1 return 1
......
...@@ -213,6 +213,7 @@ def evaluate(dataset_name, model_name, load_path, calculate_reward=True): ...@@ -213,6 +213,7 @@ def evaluate(dataset_name, model_name, load_path, calculate_reward=True):
logging.info(f'task success: {task_succ}') logging.info(f'task success: {task_succ}')
logging.info(f'book rate: {sess.evaluator.book_rate()}') logging.info(f'book rate: {sess.evaluator.book_rate()}')
logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}') logging.info(f'inform precision/recall/f1: {sess.evaluator.inform_F1()}')
logging.info(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}")
logging.info('-'*50) logging.info('-'*50)
break break
else: else:
......
...@@ -50,6 +50,7 @@ class Analyzer: ...@@ -50,6 +50,7 @@ class Analyzer:
print('task success:', sess.evaluator.task_success()) print('task success:', sess.evaluator.task_success())
print('book rate:', sess.evaluator.book_rate()) print('book rate:', sess.evaluator.book_rate())
print('inform precision/recall/f1:', sess.evaluator.inform_F1()) print('inform precision/recall/f1:', sess.evaluator.inform_F1())
print(f"percentage of domains that satisfies the database constraints: {sess.evaluator.final_goal_analyze()}")
print('-' * 50) print('-' * 50)
print('final goal:') print('final goal:')
pprint(sess.evaluator.goal) pprint(sess.evaluator.goal)
...@@ -67,6 +68,8 @@ class Analyzer: ...@@ -67,6 +68,8 @@ class Analyzer:
complete_num = 0 complete_num = 0
turn_num = 0 turn_num = 0
turn_suc_num = 0 turn_suc_num = 0
num_domains = 0
num_domains_satisfying_constraints = 0
reporter = Reporter(model_name) reporter = Reporter(model_name)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -142,6 +145,7 @@ class Analyzer: ...@@ -142,6 +145,7 @@ class Analyzer:
task_complete = sess.user_agent.policy.policy.goal.task_complete() task_complete = sess.user_agent.policy.policy.goal.task_complete()
book_rate = sess.evaluator.book_rate() book_rate = sess.evaluator.book_rate()
stats = sess.evaluator.inform_F1() stats = sess.evaluator.inform_F1()
percentage = sess.evaluator.final_goal_analyze()
if task_success: if task_success:
suc_num += 1 suc_num += 1
turn_suc_num += step turn_suc_num += step
...@@ -153,6 +157,9 @@ class Analyzer: ...@@ -153,6 +157,9 @@ class Analyzer:
f1.append(stats[2]) f1.append(stats[2])
if book_rate is not None: if book_rate is not None:
match.append(book_rate) match.append(book_rate)
if len(sess.evaluator.goal) > 0:
num_domains += len(sess.evaluator.goal)
num_domains_satisfying_constraints += len(sess.evaluator.goal) * percentage
if (j+1) % 100 == 0: if (j+1) % 100 == 0:
logger.info("model name %s", model_name) logger.info("model name %s", model_name)
logger.info("dialogue %d", j+1) logger.info("dialogue %d", j+1)
...@@ -161,6 +168,8 @@ class Analyzer: ...@@ -161,6 +168,8 @@ class Analyzer:
logger.info('task success: %.3f', suc_num/(j+1)) logger.info('task success: %.3f', suc_num/(j+1))
logger.info('book rate: %.3f', np.mean(match)) logger.info('book rate: %.3f', np.mean(match))
logger.info('inform precision/recall/f1: %.3f %.3f %.3f', np.mean(precision), np.mean(recall), np.mean(f1)) logger.info('inform precision/recall/f1: %.3f %.3f %.3f', np.mean(precision), np.mean(recall), np.mean(f1))
# Running share of goal domains whose final goal satisfies the database
# constraints; reported as 1 while no domain has been counted yet.
# Uses the local `logger` like the neighbouring progress lines, and the
# format string no longer carries a stray closing brace.
logger.info("percentage of domains that satisfies the database constraints: %.3f",
            1 if num_domains == 0 else num_domains_satisfying_constraints / num_domains)
domain_set = [] domain_set = []
for da in sess.evaluator.usr_da_array: for da in sess.evaluator.usr_da_array:
if da.split('-')[0] != 'general' and da.split('-')[0] not in domain_set: if da.split('-')[0] != 'general' and da.split('-')[0] not in domain_set:
...@@ -195,6 +204,8 @@ class Analyzer: ...@@ -195,6 +204,8 @@ class Analyzer:
print('average book rate:', np.mean(match)) print('average book rate:', np.mean(match))
print("average turn (succ):", tmp) print("average turn (succ):", tmp)
print("average turn (all):", turn_num / total_dialog) print("average turn (all):", turn_num / total_dialog)
# Overall share of goal domains satisfying the database constraints
# (1 when no domain was counted). The stray "}" after %.3f is removed so
# the number is printed cleanly.
print("percentage of domains that satisfies the database constraints: %.3f" %
      (1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains)))
print("=" * 100) print("=" * 100)
print("complete number of dialogs/tot:", complete_num / total_dialog, file=f) print("complete number of dialogs/tot:", complete_num / total_dialog, file=f)
print("success number of dialogs/tot:", suc_num / total_dialog, file=f) print("success number of dialogs/tot:", suc_num / total_dialog, file=f)
...@@ -204,6 +215,8 @@ class Analyzer: ...@@ -204,6 +215,8 @@ class Analyzer:
print('average book rate:', np.mean(match), file=f) print('average book rate:', np.mean(match), file=f)
print("average turn (succ):", tmp, file=f) print("average turn (succ):", tmp, file=f)
print("average turn (all):", turn_num / total_dialog, file=f) print("average turn (all):", turn_num / total_dialog, file=f)
# Same metric as printed to the console, written to the report file `f`.
# The stray "}" after %.3f is removed so the file contains only the number.
print("percentage of domains that satisfies the database constraints: %.3f" %
      (1 if num_domains == 0 else (num_domains_satisfying_constraints / num_domains)), file=f)
f.close() f.close()
reporter.report(complete_num/total_dialog, suc_num/total_dialog, np.mean(precision), np.mean(recall), np.mean(f1), tmp, turn_num / total_dialog) reporter.report(complete_num/total_dialog, suc_num/total_dialog, np.mean(precision), np.mean(recall), np.mean(f1), tmp, turn_num / total_dialog)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment