From 06968a024ac97f0f6492d033d6217ee8a835abfb Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com> Date: Wed, 25 Jan 2023 11:13:17 +0100 Subject: [PATCH] Save SetSUMBT belief state uncertainty statistics in dst info_dict to be saved during eval (#126) * Seperate test and train domains * Add progress bars in ontology embedder * Update custom_util.py * Fix custom_util things I broke * Github master * Save dialogue ids in prediction file * Fix bug in ontology enxtraction * Return dialogue ids in predictions file and fix bugs * Add setsumbt starting config loader * Add script to extract golden labels from dataset to match model predictions * Add more setsumbt configs * Add option to use local files only in transformers package * Update starting configurations for setsumbt * Github master * Update README.md * Update README.md * Update convlab/dialog_agent/agent.py * Revert custom_util.py * Update custom_util.py * Commit unverified chnages :(:(:(:( * Fix SetSUMBT bug resulting from new torch feature * Setsumbt bug fixes * Policy config refactor * Policy config refactor * small bug fix in memory with new config path * Setsumbt info dict Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de> Co-authored-by: Michael Heck <michael.heck@hhu.de> Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de> --- convlab/dst/setsumbt/tracker.py | 6 ++++-- convlab/util/custom_util.py | 8 ++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index e4033204..f56bbadc 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -27,7 +27,7 @@ class SetSUMBTTracker(DST): confidence_threshold='auto', return_belief_state_entropy: bool = False, return_belief_state_mutual_info: bool = False, - store_full_belief_state: bool = False): + store_full_belief_state: bool = True): """ Args: model_path: Model path or download URL @@ -326,7 +326,9 @@ class SetSUMBTTracker(DST): dialogue_state[dom][slot] = val if self.store_full_belief_state: - self.full_belief_state = belief_state + self.info_dict['belief_state_distributions'] = belief_state + if state_mutual_info is not None: + self.info_dict['belief_state_knowledge_uncertainty'] = state_mutual_info # Obtain model output probabilities if self.return_confidence_scores: diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py index 5c6b0d33..3bc6550f 100644 --- a/convlab/util/custom_util.py +++ b/convlab/util/custom_util.py @@ -169,7 +169,7 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do single_domain_goals, allowed_domains) goals.append(goal[0]) - if conf['model']['process_num'] == 1: + if conf['model']['process_num'] == 1 or save_eval: complete_rate, success_rate, success_rate_strict, avg_return, turns, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \ select_acts, offer_acts, recommend_acts = evaluate(sess, @@ -330,7 +330,6 @@ def create_env(args, policy_sys): def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False, save_path=None, goals=None): - eval_save = {} turn_counter_dict = {} turn_counter = 0.0 @@ -426,10 +425,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False # print('length of dict ' + str(len(eval_save))) if save_flag: - # print("what are you doing") - save_file = open(os.path.join(save_path, 'evaluate_INFO.json'), 'w') - json.dump(eval_save, save_file, cls=NumpyEncoder) - save_file.close() + torch.save(eval_save, os.path.join(save_path, 'evaluate_INFO.pt')) # save dialogue_info and clear mem return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \ -- GitLab