From f15a6fa473632cbb9965689335ac5183bb0c5c66 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <vniekerk.carel@gmail.com>
Date: Wed, 12 Apr 2023 12:04:53 +0200
Subject: [PATCH] Auto save distributions and mutual info

---
 convlab/dst/setsumbt/tracker.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py
index 5126fd34..a4494ceb 100644
--- a/convlab/dst/setsumbt/tracker.py
+++ b/convlab/dst/setsumbt/tracker.py
@@ -263,7 +263,7 @@ class SetSUMBTTracker(DST):
             new_state['turn_pooled_representation'] = outputs.turn_pooled_representation.reshape(-1)
 
         self.state = new_state
-        self.info_dict['belief_state'] = copy.deepcopy(dict(new_state))
+        # self.info_dict['belief_state'] = copy.deepcopy(dict(new_state))
 
         return self.state
 
@@ -281,7 +281,8 @@ class SetSUMBTTracker(DST):
         with torch.no_grad():
             features['hidden_state'] = self.hidden_states
             features['get_turn_pooled_representation'] = self.return_turn_pooled_representation
-            features['calculate_state_mutual_info'] = self.return_belief_state_mutual_info
+            mutual_info = self.return_belief_state_mutual_info or self.store_full_belief_state
+            features['calculate_state_mutual_info'] = mutual_info
             outputs = self.model(**features)
             self.hidden_states = outputs.hidden_state
 
@@ -293,8 +294,7 @@ class SetSUMBTTracker(DST):
 
         if self.store_full_belief_state:
             self.info_dict['belief_state_distributions'] = outputs.belief_state
-            if state_mutual_info is not None:
-                self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information
+            self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information
 
         # Obtain model output probabilities
         if self.return_confidence_scores:
-- 
GitLab