Skip to content
Snippets Groups Projects
Unverified Commit 06968a02 authored by Carel van Niekerk's avatar Carel van Niekerk Committed by GitHub
Browse files

Save SetSUMBT belief state uncertainty statistics in dst info_dict to be saved during eval (#126)


* Separate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology extraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified changes :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Setsumbt bug fixes

* Policy config refactor

* Policy config refactor

* small bug fix in memory with new config path

* Setsumbt info dict

Co-authored-by: default avatarCarel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: default avatarMichael Heck <michael.heck@hhu.de>
Co-authored-by: default avatarChristian Geishauser <christian.geishauser@hhu.de>
parent 182d847b
No related branches found
No related tags found
No related merge requests found
...@@ -27,7 +27,7 @@ class SetSUMBTTracker(DST): ...@@ -27,7 +27,7 @@ class SetSUMBTTracker(DST):
confidence_threshold='auto', confidence_threshold='auto',
return_belief_state_entropy: bool = False, return_belief_state_entropy: bool = False,
return_belief_state_mutual_info: bool = False, return_belief_state_mutual_info: bool = False,
store_full_belief_state: bool = False): store_full_belief_state: bool = True):
""" """
Args: Args:
model_path: Model path or download URL model_path: Model path or download URL
...@@ -326,7 +326,9 @@ class SetSUMBTTracker(DST): ...@@ -326,7 +326,9 @@ class SetSUMBTTracker(DST):
dialogue_state[dom][slot] = val dialogue_state[dom][slot] = val
if self.store_full_belief_state: if self.store_full_belief_state:
self.full_belief_state = belief_state self.info_dict['belief_state_distributions'] = belief_state
if state_mutual_info is not None:
self.info_dict['belief_state_knowledge_uncertainty'] = state_mutual_info
# Obtain model output probabilities # Obtain model output probabilities
if self.return_confidence_scores: if self.return_confidence_scores:
......
...@@ -169,7 +169,7 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do ...@@ -169,7 +169,7 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do
single_domain_goals, allowed_domains) single_domain_goals, allowed_domains)
goals.append(goal[0]) goals.append(goal[0])
if conf['model']['process_num'] == 1: if conf['model']['process_num'] == 1 or save_eval:
complete_rate, success_rate, success_rate_strict, avg_return, turns, \ complete_rate, success_rate, success_rate_strict, avg_return, turns, \
avg_actions, task_success, book_acts, inform_acts, request_acts, \ avg_actions, task_success, book_acts, inform_acts, request_acts, \
select_acts, offer_acts, recommend_acts = evaluate(sess, select_acts, offer_acts, recommend_acts = evaluate(sess,
...@@ -330,7 +330,6 @@ def create_env(args, policy_sys): ...@@ -330,7 +330,6 @@ def create_env(args, policy_sys):
def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False, save_path=None, goals=None): def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False, save_path=None, goals=None):
eval_save = {} eval_save = {}
turn_counter_dict = {} turn_counter_dict = {}
turn_counter = 0.0 turn_counter = 0.0
...@@ -426,10 +425,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False ...@@ -426,10 +425,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
# print('length of dict ' + str(len(eval_save))) # print('length of dict ' + str(len(eval_save)))
if save_flag: if save_flag:
# print("what are you doing") torch.save(eval_save, os.path.join(save_path, 'evaluate_INFO.pt'))
save_file = open(os.path.join(save_path, 'evaluate_INFO.json'), 'w')
json.dump(eval_save, save_file, cls=NumpyEncoder)
save_file.close()
# save dialogue_info and clear mem # save dialogue_info and clear mem
return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \ return task_success['All_user_sim'], task_success['All_evaluator'], task_success['All_evaluator_strict'], \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment