diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index e40332048188b7f5a1f43e397896a8b6b201553d..f56bbadc2f4d8fdca102b2bbc996acb0ae5a4a58 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -27,7 +27,7 @@ class SetSUMBTTracker(DST): confidence_threshold='auto', return_belief_state_entropy: bool = False, return_belief_state_mutual_info: bool = False, - store_full_belief_state: bool = False): + store_full_belief_state: bool = True): """ Args: model_path: Model path or download URL @@ -326,7 +326,9 @@ class SetSUMBTTracker(DST): dialogue_state[dom][slot] = val if self.store_full_belief_state: - self.full_belief_state = belief_state + self.info_dict['belief_state_distributions'] = belief_state + if state_mutual_info is not None: + self.info_dict['belief_state_knowledge_uncertainty'] = state_mutual_info # Obtain model output probabilities if self.return_confidence_scores: diff --git a/convlab/util/custom_util.py b/convlab/util/custom_util.py index 5c6b0d33f1755cf443521307024d6e17280f060d..3bc6550fe903b86f38240b0061f41f91d6326e3e 100644 --- a/convlab/util/custom_util.py +++ b/convlab/util/custom_util.py @@ -169,7 +169,7 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do single_domain_goals, allowed_domains) goals.append(goal[0]) - if conf['model']['process_num'] == 1: + if conf['model']['process_num'] == 1 or save_eval: complete_rate, success_rate, success_rate_strict, avg_return, turns, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \ select_acts, offer_acts, recommend_acts = evaluate(sess, @@ -330,7 +330,6 @@ def create_env(args, policy_sys): def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False, save_path=None, goals=None): - eval_save = {} turn_counter_dict = {} turn_counter = 0.0 @@ -426,10 +425,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False # print('length of dict ' + str(len(eval_save))) if save_flag: - # print("what are you doing") - save_file = open(os.path.join(save_path, 'evaluate_INFO.json'), 'w') - json.dump(eval_save, save_file, cls=NumpyEncoder) - save_file.close() + torch.save(eval_save, os.path.join(save_path, 'evaluate_INFO.pt')) # save dialogue_info and clear mem return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \