Skip to content
Snippets Groups Projects
Commit f15a6fa4 authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

Auto save distributions and mutual info

parent 8a231167
No related branches found
No related tags found
No related merge requests found
...@@ -263,7 +263,7 @@ class SetSUMBTTracker(DST): ...@@ -263,7 +263,7 @@ class SetSUMBTTracker(DST):
new_state['turn_pooled_representation'] = outputs.turn_pooled_representation.reshape(-1) new_state['turn_pooled_representation'] = outputs.turn_pooled_representation.reshape(-1)
self.state = new_state self.state = new_state
self.info_dict['belief_state'] = copy.deepcopy(dict(new_state)) # self.info_dict['belief_state'] = copy.deepcopy(dict(new_state))
return self.state return self.state
...@@ -281,7 +281,8 @@ class SetSUMBTTracker(DST): ...@@ -281,7 +281,8 @@ class SetSUMBTTracker(DST):
with torch.no_grad(): with torch.no_grad():
features['hidden_state'] = self.hidden_states features['hidden_state'] = self.hidden_states
features['get_turn_pooled_representation'] = self.return_turn_pooled_representation features['get_turn_pooled_representation'] = self.return_turn_pooled_representation
features['calculate_state_mutual_info'] = self.return_belief_state_mutual_info mutual_info = self.return_belief_state_mutual_info or self.store_full_belief_state
features['calculate_state_mutual_info'] = mutual_info
outputs = self.model(**features) outputs = self.model(**features)
self.hidden_states = outputs.hidden_state self.hidden_states = outputs.hidden_state
...@@ -293,7 +294,6 @@ class SetSUMBTTracker(DST): ...@@ -293,7 +294,6 @@ class SetSUMBTTracker(DST):
if self.store_full_belief_state: if self.store_full_belief_state:
self.info_dict['belief_state_distributions'] = outputs.belief_state self.info_dict['belief_state_distributions'] = outputs.belief_state
if state_mutual_info is not None:
self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information
# Obtain model output probabilities # Obtain model output probabilities
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment