diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index 5126fd3439c4dd77a2f27720bdd68f7c0f7e947a..a4494ceb94c477bde662dcd7719631ff827a6b53 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -263,7 +263,7 @@ class SetSUMBTTracker(DST): new_state['turn_pooled_representation'] = outputs.turn_pooled_representation.reshape(-1) self.state = new_state - self.info_dict['belief_state'] = copy.deepcopy(dict(new_state)) + # self.info_dict['belief_state'] = copy.deepcopy(dict(new_state)) return self.state @@ -281,7 +281,8 @@ class SetSUMBTTracker(DST): with torch.no_grad(): features['hidden_state'] = self.hidden_states features['get_turn_pooled_representation'] = self.return_turn_pooled_representation - features['calculate_state_mutual_info'] = self.return_belief_state_mutual_info + mutual_info = self.return_belief_state_mutual_info or self.store_full_belief_state + features['calculate_state_mutual_info'] = mutual_info outputs = self.model(**features) self.hidden_states = outputs.hidden_state @@ -293,8 +294,7 @@ class SetSUMBTTracker(DST): if self.store_full_belief_state: self.info_dict['belief_state_distributions'] = outputs.belief_state - if state_mutual_info is not None: - self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information + self.info_dict['belief_state_knowledge_uncertainty'] = outputs.belief_state_mutual_information # Obtain model output probabilities if self.return_confidence_scores: