    # -*- coding: utf-8 -*-
    # Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
    # Authors: Carel van Niekerk (niekerk@hhu.de)
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    """SetSUMBT Prediction Head"""
    
    import torch
    from torch.autograd import Variable
    from torch.nn import (Module, MultiheadAttention, GRU, LSTM, Linear, LayerNorm, Dropout,
                          CosineSimilarity, CrossEntropyLoss, PairwiseDistance,
                          Sequential, ReLU, Conv1d, GELU, BCEWithLogitsLoss)
    from torch.nn.init import (xavier_normal_, constant_)
    
    from convlab.dst.setsumbt.loss import (BayesianMatchingLoss, BinaryBayesianMatchingLoss,
                                           KLDistillationLoss, BinaryKLDistillationLoss,
                                           LabelSmoothingLoss, BinaryLabelSmoothingLoss,
                                           RKLDirichletMediatorLoss, BinaryRKLDirichletMediatorLoss)
    
    
    class SlotUtteranceMatching(Module):
        """Slot Utterance matching attention based information extractor"""
    
        def __init__(self, hidden_size: int = 768, attention_heads: int = 12):
            """
            Args:
                hidden_size (int): Dimension of token embeddings
                attention_heads (int): Number of attention heads to use in attention module
            """
            super(SlotUtteranceMatching, self).__init__()
    
            self.attention = MultiheadAttention(hidden_size, attention_heads)
    
        def forward(self,
                    turn_embeddings: torch.Tensor,
                    attention_mask: torch.Tensor,
                    slot_embeddings: torch.Tensor) -> torch.Tensor:
            """
            Args:
                turn_embeddings: Embeddings for each token in each turn [n_turns, turn_length, hidden_size]
                attention_mask: Padding mask for each turn [n_turns, turn_length, hidden_size]
                slot_embeddings: Embeddings for each token in the slot descriptions
    
            Returns:
                hidden: Information extracted from turn related to slot descriptions
            """
            turn_embeddings = turn_embeddings.transpose(0, 1)
    
            key_padding_mask = (attention_mask[:, :, 0] == 0.0)
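            # A fully padded turn (first token masked) would make every key -inf and
            # yield NaN attention weights; unmask such turns here, they are zeroed below.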
            key_padding_mask[key_padding_mask[:, 0], :] = False
    
            hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
                                       key_padding_mask=key_padding_mask)
    
            attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
            hidden = hidden * attention_mask
    
            return hidden
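
    # Illustrative usage sketch (not part of the original module): shapes follow the
    # docstring above, with the slot description tokens acting as attention queries.
    #   matcher = SlotUtteranceMatching(hidden_size=768, attention_heads=12)
    #   turns = torch.randn(6, 32, 768)       # [n_turns, turn_length, hidden_size]
    #   mask = torch.ones(6, 32, 768)         # 1.0 for real tokens, 0.0 for padding
    #   slots = torch.randn(9, 6, 768)        # [slot_tokens, n_turns, hidden_size]
    #   hidden = matcher(turns, mask, slots)  # -> [slot_tokens, n_turns, hidden_size]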
    
    
    class RecurrentNeuralBeliefTracker(Module):
        """Recurrent latent neural belief tracking module"""
    
        def __init__(self,
                     nbt_type: str = 'gru',
                     rnn_zero_init: bool = False,
                     input_size: int = 768,
                     hidden_size: int = 300,
                     hidden_layers: int = 1,
                     dropout_rate: float = 0.3):
            """
            Args:
                nbt_type: Type of recurrent neural network (gru/lstm)
                rnn_zero_init: Use zero initialised state for the RNN
                input_size: Embedding size of the inputs
                hidden_size: Hidden size of the RNN
                hidden_layers: Number of RNN Layers
                dropout_rate: Dropout rate
            """
            super(RecurrentNeuralBeliefTracker, self).__init__()
    
            if rnn_zero_init:
                self.belief_init = Sequential(Linear(input_size, hidden_size), ReLU(), Dropout(dropout_rate))
            else:
                self.belief_init = None
    
            self.nbt_type = nbt_type
            self.hidden_layers = hidden_layers
            self.hidden_size = hidden_size
            if nbt_type == 'gru':
                self.nbt = GRU(input_size=input_size,
                               hidden_size=hidden_size,
                               num_layers=hidden_layers,
                               dropout=0.0 if hidden_layers == 1 else dropout_rate,
                               batch_first=True)
            elif nbt_type == 'lstm':
                self.nbt = LSTM(input_size=input_size,
                                hidden_size=hidden_size,
                                num_layers=hidden_layers,
                                dropout=0.0 if hidden_layers == 1 else dropout_rate,
                                batch_first=True)
            else:
                raise NotImplementedError(f'Unknown nbt_type: {nbt_type}')
    
            # Initialise Parameters
            xavier_normal_(self.nbt.weight_ih_l0)
            xavier_normal_(self.nbt.weight_hh_l0)
            constant_(self.nbt.bias_ih_l0, 0.0)
            constant_(self.nbt.bias_hh_l0, 0.0)
    
            # Intermediate feature mapping and layer normalisation
            self.intermediate = Linear(hidden_size, input_size)
            self.layer_norm = LayerNorm(input_size)
            self.dropout = Dropout(dropout_rate)
    
        def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor = None) -> tuple:
            """
            Args:
                inputs: Latent turn level information
                hidden_state: Latent internal belief state
    
            Returns:
                belief_embedding: Belief state embeddings
                context: Latent internal belief state
            """
            self.nbt.flatten_parameters()
            if hidden_state is None:
                if self.belief_init is None:
                    context = torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device)
                else:
                    context = self.belief_init(inputs[:, 0, :]).unsqueeze(0).repeat((self.hidden_layers, 1, 1))
                if self.nbt_type == "lstm":
                    context = (context, torch.zeros(self.hidden_layers, inputs.size(0), self.hidden_size).to(inputs.device))
            else:
                if self.nbt_type == 'lstm':
                    # LSTM states are (hidden, cell) tuples and must be moved element-wise
                    context = tuple(state.to(inputs.device) for state in hidden_state)
                else:
                    context = hidden_state.to(inputs.device)
    
            # [batch_size, dialogue_size, nbt_hidden_size]
            belief_embedding, context = self.nbt(inputs, context)
    
            # Normalisation and regularisation
            belief_embedding = self.layer_norm(self.intermediate(belief_embedding))
            belief_embedding = self.dropout(belief_embedding)
    
            return belief_embedding, context
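
    # Illustrative usage sketch (not part of the original module): the returned context
    # can be fed back in for incremental, turn-by-turn tracking.
    #   nbt = RecurrentNeuralBeliefTracker(nbt_type='gru', input_size=768, hidden_size=300)
    #   turn_level = torch.randn(8, 12, 768)   # [batch * num_slots, dialogue_size, 768]
    #   belief, context = nbt(turn_level)                    # full dialogue in one pass
    #   belief, context = nbt(turn_level[:, -1:], context)   # or one further turn at a time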
    
    
    class SetPooler(Module):
        """Token set pooler"""
    
        def __init__(self, pooling_strategy: str = 'cnn', hidden_size: int = 768):
            """
            Args:
                pooling_strategy: Type of set pooler (cnn/dan/mean)
                hidden_size: Token embedding size
            """
            super(SetPooler, self).__init__()
    
            self.pooling_strategy = pooling_strategy
            if pooling_strategy == 'cnn':
                self.cnn_filter_size = 3
                self.pooler = Conv1d(hidden_size, hidden_size, self.cnn_filter_size)
            elif pooling_strategy == 'dan':
                # The original layer sizes mismatched; assume the first layer expands to
                # 2 * hidden_size so the two linear layers compose
                self.pooler = Sequential(Linear(hidden_size, 2 * hidden_size), GELU(),
                                         Linear(2 * hidden_size, hidden_size))
            elif pooling_strategy != 'mean':
                raise NotImplementedError(f'Unknown pooling strategy: {pooling_strategy}')
    
        def forward(self, inputs, attention_mask):
            """
            Args:
                inputs: Token set embeddings
                attention_mask: Padding mask for the set of tokens
    
            Returns:
                hidden: Pooled representation of the token set
            """
            if self.pooling_strategy == "mean":
                hidden = inputs.sum(1) / attention_mask.sum(1)
            elif self.pooling_strategy == "cnn":
                hidden = self.pooler(inputs.transpose(1, 2)).mean(-1)
            elif self.pooling_strategy == 'dan':
                hidden = inputs.sum(1) / attention_mask.sum(1).sqrt()
                hidden = self.pooler(hidden)
    
            return hidden
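
    # Illustrative usage sketch (not part of the original module): every strategy maps a
    # [batch, set_size, hidden] token set to a single [batch, hidden] vector.
    #   pooler = SetPooler(pooling_strategy='mean', hidden_size=768)
    #   tokens = torch.randn(4, 10, 768)
    #   mask = torch.ones(4, 10, 768)
    #   pooled = pooler(tokens, mask)  # -> [4, 768]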
    
    
    class SetSUMBTHead(Module):
        """SetSUMBT Prediction Head for Language Models"""
    
        def __init__(self, config):
            """
            Args:
                config (configuration): Model configuration class
            """
            super(SetSUMBTHead, self).__init__()
            self.config = config
            # Slot Utterance matching attention
            self.slot_utterance_matching = SlotUtteranceMatching(config.hidden_size, config.slot_attention_heads)
    
            # Latent context tracker
            self.nbt = RecurrentNeuralBeliefTracker(config.nbt_type, config.rnn_zero_init, config.hidden_size,
                                                    config.nbt_hidden_size, config.nbt_layers, config.dropout_rate)
    
            # Set pooler for set similarity model
            if self.config.set_similarity:
                self.set_pooler = SetPooler(config.set_pooling, config.hidden_size)
    
            # Model ontology placeholders
            self.slot_embeddings = Variable(torch.zeros(0), requires_grad=False)
            self.slot_ids = dict()
            self.requestable_slot_ids = dict()
            self.informable_slot_ids = dict()
            self.domain_ids = dict()
    
            # Matching network similarity measure
            if config.distance_measure == 'cosine':
                self.distance = CosineSimilarity(dim=-1, eps=1e-8)
            elif config.distance_measure == 'euclidean':
                self.distance = PairwiseDistance(p=2.0, eps=1e-6, keepdim=False)
            else:
                raise NotImplementedError(f'Unknown distance measure: {config.distance_measure}')
    
            # User goal prediction loss function
            if config.loss_function == 'crossentropy':
                self.loss = CrossEntropyLoss(ignore_index=-1)
            elif config.loss_function == 'bayesianmatching':
                self.loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
            elif config.loss_function == 'labelsmoothing':
                self.loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
            elif config.loss_function == 'distillation':
                self.loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
                self.temp = 1.0
            elif config.loss_function == 'distribution_distillation':
                self.loss = RKLDirichletMediatorLoss(ignore_index=-1)
            else:
                raise NotImplementedError(f'Unknown loss function: {config.loss_function}')
    
            # Intent and domain prediction heads
            if config.predict_actions:
                self.request_gate = Linear(config.hidden_size, 1)
                self.general_act_gate = Linear(config.hidden_size, 3)
                self.active_domain_gate = Linear(config.hidden_size, 1)
    
                # Intent and domain loss function
                self.request_weight = float(self.config.user_request_loss_weight)
                self.general_act_weight = float(self.config.user_general_act_loss_weight)
                self.active_domain_weight = float(self.config.active_domain_loss_weight)
                if config.loss_function == 'crossentropy':
                    self.request_loss = BCEWithLogitsLoss()
                    self.general_act_loss = CrossEntropyLoss(ignore_index=-1)
                    self.active_domain_loss = BCEWithLogitsLoss()
                elif config.loss_function == 'labelsmoothing':
                    self.request_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
                    self.general_act_loss = LabelSmoothingLoss(ignore_index=-1, label_smoothing=config.label_smoothing)
                    self.active_domain_loss = BinaryLabelSmoothingLoss(label_smoothing=config.label_smoothing)
                elif config.loss_function == 'bayesianmatching':
                    self.request_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
                    self.general_act_loss = BayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
                    self.active_domain_loss = BinaryBayesianMatchingLoss(ignore_index=-1, lamb=config.kl_scaling_factor)
                elif config.loss_function == 'distillation':
                    self.request_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
                    self.general_act_loss = KLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
                    self.active_domain_loss = BinaryKLDistillationLoss(ignore_index=-1, lamb=config.ensemble_smoothing)
                elif config.loss_function == 'distribution_distillation':
                    self.request_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
                    self.general_act_loss = RKLDirichletMediatorLoss(ignore_index=-1)
                    self.active_domain_loss = BinaryRKLDirichletMediatorLoss(ignore_index=-1)
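
        # The config object is assumed to provide (at least): hidden_size,
        # slot_attention_heads, nbt_type, rnn_zero_init, nbt_hidden_size, nbt_layers,
        # dropout_rate, set_similarity, set_pooling, distance_measure, loss_function and
        # predict_actions, plus the loss-specific factors and weights read above.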
    
        def add_slot_candidates(self, slot_candidates: dict):
            """
            Add slots to the model ontology. Each entry maps a slot name to a tuple of
            (slot embedding, informable value embeddings, request indicator): if the value
            embeddings are None the slot is not informable, and if the request indicator
            is False the slot is not requestable.

            Args:
                slot_candidates: Dict mapping slot names to (slot embedding, informable
                    value embeddings, request indicator) tuples
            """
            if self.slot_embeddings.size(0) != 0:
                embeddings = self.slot_embeddings.detach()
            else:
                embeddings = torch.zeros(0)
    
            for slot in slot_candidates:
                if slot in self.slot_ids:
                    index = self.slot_ids[slot]
                    embeddings[index, :] = slot_candidates[slot][0]
                else:
                    index = embeddings.size(0)
                    emb = slot_candidates[slot][0].unsqueeze(0).to(embeddings.device)
                    embeddings = torch.cat((embeddings, emb), 0)
                    self.slot_ids[slot] = index
                    setattr(self, slot + '_value_embeddings', Variable(torch.zeros(0), requires_grad=False))
                # Add slot to relevant requestable and informable slot lists
                if slot_candidates[slot][2]:
                    self.requestable_slot_ids[slot] = index
                if slot_candidates[slot][1] is not None:
                    self.informable_slot_ids[slot] = index
    
                domain = slot.split('-', 1)[0]
                if domain not in self.domain_ids:
                    self.domain_ids[domain] = []
                self.domain_ids[domain].append(index)
                self.domain_ids[domain] = list(set(self.domain_ids[domain]))
    
            self.slot_embeddings = Variable(embeddings, requires_grad=False)
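
        # Illustrative call sketch (not part of the original module); the embedding
        # variables are hypothetical:
        #   head.add_slot_candidates({
        #       'hotel-name': (name_desc_emb, name_value_embs, False),   # informable only
        #       'hotel-phone': (phone_desc_emb, None, True),             # requestable only
        #   })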
    
        def add_value_candidates(self, slot: str, value_candidates: torch.Tensor, replace: bool = False):
            """
            Add value candidates for a slot
    
            Args:
                slot: Slot name
                value_candidates: Value candidate embeddings
                replace: If True, existing value candidates are replaced
            """
            embeddings = getattr(self, slot + '_value_embeddings')
    
            if embeddings.size(0) == 0 or replace:
                embeddings = value_candidates
            else:
                embeddings = torch.cat((embeddings, value_candidates.to(embeddings.device)), 0)
    
            setattr(self, slot + '_value_embeddings', embeddings)
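
        # Illustrative call sketch (not part of the original module): swap in freshly
        # encoded value embeddings for a slot.
        #   head.add_value_candidates('hotel-name', new_value_embs, replace=True)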
    
        def forward(self,
                    turn_embeddings: torch.Tensor,
                    turn_pooled_representation: torch.Tensor,
                    attention_mask: torch.Tensor,
                    batch_size: int,
                    dialogue_size: int,
                    hidden_state: torch.Tensor = None,
                    state_labels: torch.Tensor = None,
                    request_labels: torch.Tensor = None,
                    active_domain_labels: torch.Tensor = None,
                    general_act_labels: torch.Tensor = None,
                    calculate_state_mutual_info: bool = False):
            """
            Args:
                turn_embeddings: Token embeddings in the current turn
                turn_pooled_representation: Pooled representation of the current dialogue turn
                attention_mask: Padding mask for the current dialogue turn
                batch_size: Number of dialogues in the batch
                dialogue_size: Number of turns in each dialogue
                hidden_state: Latent internal dialogue belief state
                state_labels: Dialogue state labels
                request_labels: User request action labels
                active_domain_labels: Current active domain labels
                general_act_labels: General user action labels
                calculate_state_mutual_info: Return mutual information in the dialogue state
    
            Returns:
                out: Tuple containing loss, predictive distributions, model statistics and state mutual information
            """
            hidden_size = turn_embeddings.size(-1)
            # Initialise loss
            loss = 0.0
    
            # General Action predictions
            general_act_probs = None
            if self.config.predict_actions:
                # General action prediction
                general_act_logits = self.general_act_gate(turn_pooled_representation.reshape(batch_size * dialogue_size,
                                                                                              hidden_size))
    
                # Compute loss for general action predictions (weighted loss)
                if general_act_labels is not None:
                    if self.config.loss_function == 'distillation':
                        general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-1))
                        loss += self.general_act_loss(general_act_logits, general_act_labels,
                                                      self.temp) * self.general_act_weight
                    elif self.config.loss_function == 'distribution_distillation':
                        general_act_labels = general_act_labels.reshape(-1, general_act_labels.size(-2),
                                                                        general_act_labels.size(-1))
                        loss += self.general_act_loss(general_act_logits, general_act_labels)[0] * self.general_act_weight
                    else:
                        general_act_labels = general_act_labels.reshape(-1)
                        loss += self.general_act_loss(general_act_logits, general_act_labels) * self.general_act_weight
    
                # Compute general action probabilities
                general_act_probs = torch.softmax(general_act_logits, -1).reshape(batch_size, dialogue_size, -1)
    
            # Slot utterance matching
            num_slots = self.slot_embeddings.size(0)
            slot_embeddings = self.slot_embeddings.reshape(-1, hidden_size)
            slot_embeddings = slot_embeddings.unsqueeze(1).repeat((1, batch_size * dialogue_size, 1))
            slot_embeddings = slot_embeddings.to(turn_embeddings.device)
    
            if self.config.set_similarity:
                # Slot mask shape [num_slots * slot_len, batch_size * dialogue_size, 768]
                slot_mask = (slot_embeddings != 0.0).float()
    
            hidden = self.slot_utterance_matching(turn_embeddings, attention_mask, slot_embeddings)
    
            if self.config.set_similarity:
                hidden = hidden * slot_mask
            # Hidden layer shape [num_dials, num_slots, num_turns, 768]
            hidden = hidden.transpose(0, 1).reshape(batch_size, dialogue_size, slot_embeddings.size(0), -1).transpose(1, 2)
    
            # Latent context tracking
            # [batch_size * num_slots, dialogue_size, 768]
            hidden = hidden.reshape(batch_size * slot_embeddings.size(0), dialogue_size, -1)
            belief_embedding, hidden_state = self.nbt(hidden, hidden_state)
    
            belief_embedding = belief_embedding.reshape(batch_size, slot_embeddings.size(0),
                                                        dialogue_size, -1).transpose(1, 2)
            if self.config.set_similarity:
                belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
                                                            self.config.hidden_size)
            # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
    
            # Pooling of the set of latent context representation
            if self.config.set_similarity:
                slot_mask = slot_mask.transpose(0, 1).reshape(batch_size, dialogue_size, num_slots, -1, hidden_size)
                belief_embedding = belief_embedding * slot_mask
    
                belief_embedding = self.set_pooler(belief_embedding.reshape(-1, slot_mask.size(-2), hidden_size),
                                                   slot_mask.reshape(-1, slot_mask.size(-2), hidden_size))
                belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
    
            # Perform classification
            # Get padded batch, dialogue idx pairs
            batches, dialogues = torch.where(attention_mask[:, 0, 0].reshape(batch_size, dialogue_size) == 0.0)
            
            if self.config.predict_actions:
                # User request prediction
                request_probs = dict()
                for slot, slot_id in self.requestable_slot_ids.items():
                    request_logits = self.request_gate(belief_embedding[:, :, slot_id, :])
    
                    # Store output probabilities
                    request_logits = request_logits.reshape(batch_size, dialogue_size)
                    # Set request scores to 0.0 for padded turns
                    request_logits[batches, dialogues] = 0.0
                    request_probs[slot] = torch.sigmoid(request_logits)
    
                    if request_labels is not None:
                        # Compute request gate loss
                        request_logits = request_logits.reshape(-1)
                        if self.config.loss_function == 'distillation':
                            loss += self.request_loss(request_logits, request_labels[slot].reshape(-1),
                                                      self.temp) * self.request_weight
                        elif self.config.loss_function == 'distribution_distillation':
                            loss += self.request_loss(request_logits, request_labels[slot])[0] * self.request_weight
                        else:
                            labs = request_labels[slot].reshape(-1)
                            request_logits = request_logits[labs != -1]
                            labs = labs[labs != -1].float()
                            loss += self.request_loss(request_logits, labs) * self.request_weight
    
                # Active domain prediction
                active_domain_probs = dict()
                for domain, slot_ids in self.domain_ids.items():
                    belief = belief_embedding[:, :, slot_ids, :]
                    if len(slot_ids) > 1:
                        # SqrtN reduction across all slots within a domain
                        belief = belief.sum(2) / ((belief != 0.0).float().sum(2) ** 0.5)
                    active_domain_logits = self.active_domain_gate(belief)
    
                    # Store output probabilities
                    active_domain_logits = active_domain_logits.reshape(batch_size, dialogue_size)
                    active_domain_logits[batches, dialogues] = 0.0
                    active_domain_probs[domain] = torch.sigmoid(active_domain_logits)
    
                    if active_domain_labels is not None and domain in active_domain_labels:
                        # Compute domain prediction loss
                        active_domain_logits = active_domain_logits.reshape(-1)
                        if self.config.loss_function == 'distillation':
                            loss += self.active_domain_loss(active_domain_logits, active_domain_labels[domain].reshape(-1),
                                                            self.temp) * self.active_domain_weight
                        elif self.config.loss_function == 'distribution_distillation':
                            loss += self.active_domain_loss(active_domain_logits,
                                                            active_domain_labels[domain])[0] * self.active_domain_weight
                        else:
                            labs = active_domain_labels[domain].reshape(-1)
                            active_domain_logits = active_domain_logits[labs != -1]
                            labs = labs[labs != -1].float()
                            loss += self.active_domain_loss(active_domain_logits, labs) * self.active_domain_weight
            else:
                request_probs, active_domain_probs = None, None
    
            # Dialogue state predictions
            belief_state_probs = dict()
            belief_state_mutual_info = dict()
            belief_state_stats = dict()
            for slot, slot_id in self.informable_slot_ids.items():
                # Get slot belief embedding and value candidates
                candidate_embeddings = getattr(self, slot + '_value_embeddings').to(turn_embeddings.device)
                belief = belief_embedding[:, :, slot_id, :]
                slot_size = candidate_embeddings.size(0)
    
                belief = belief.unsqueeze(2).repeat((1, 1, slot_size, 1))
                belief = belief.reshape(-1, self.config.hidden_size)
    
                if self.config.set_similarity:
                    candidate_embeddings = self.set_pooler(candidate_embeddings, (candidate_embeddings != 0.0).float())
                candidate_embeddings = candidate_embeddings.unsqueeze(0).unsqueeze(0).repeat((batch_size,
                                                                                              dialogue_size, 1, 1))
                candidate_embeddings = candidate_embeddings.reshape(-1, self.config.hidden_size)
    
                # Score value candidates
                if self.config.distance_measure == 'cosine':
                    logits = self.distance(belief, candidate_embeddings)
                    # *27 here rescales the cosine similarity for better learning
                    logits = logits.reshape(batch_size * dialogue_size, -1) * 27.0
                elif self.config.distance_measure == 'euclidean':
                    logits = -1.0 * self.distance(belief, candidate_embeddings)
                    logits = logits.reshape(batch_size * dialogue_size, -1)
    
                # Calculate belief state
                probs_ = torch.softmax(logits.reshape(batch_size, dialogue_size, -1), -1)
    
                # Compute knowledge uncertainty in the belief states
                if calculate_state_mutual_info and self.config.loss_function == 'distribution_distillation':
                    belief_state_mutual_info[slot] = self.loss.logits_to_mutual_info(logits).reshape(batch_size, dialogue_size)
    
                # Set padded turn probabilities to zero
                probs_[batches, dialogues, :] = 0.0
                belief_state_probs[slot] = probs_
    
                # Calculate belief state loss
                if state_labels is not None and slot in state_labels:
                    if self.config.loss_function == 'bayesianmatching':
                        prior = torch.ones(logits.size(-1)).float().to(logits.device)
                        prior = prior * self.config.prior_constant
                        prior = prior.unsqueeze(0).repeat((logits.size(0), 1))
    
                        loss += self.loss(logits, state_labels[slot].reshape(-1), prior=prior)
                    elif self.config.loss_function == 'distillation':
                        labels = state_labels[slot]
                        labels = labels.reshape(-1, labels.size(-1))
                        loss += self.loss(logits, labels, self.temp)
                    elif self.config.loss_function == 'distribution_distillation':
                        labels = state_labels[slot]
                        labels = labels.reshape(-1, labels.size(-2), labels.size(-1))
                        loss_, model_stats, ensemble_stats = self.loss(logits, labels)
                        loss += loss_
    
                        # Calculate stats regarding model precisions
                        precision = model_stats['precision']
                        ensemble_precision = ensemble_stats['precision']
                        belief_state_stats[slot] = {'model_precision_min': precision.min(),
                                                    'model_precision_max': precision.max(),
                                                    'model_precision_mean': precision.mean(),
                                                    'ensemble_precision_min': ensemble_precision.min(),
                                                    'ensemble_precision_max': ensemble_precision.max(),
                                                    'ensemble_precision_mean': ensemble_precision.mean()}
                    else:
                        loss += self.loss(logits, state_labels[slot].reshape(-1))
    
            # Return model outputs
            out = belief_state_probs, request_probs, active_domain_probs, general_act_probs, hidden_state
            if state_labels is not None or request_labels is not None:
                out = (loss,) + out + (belief_state_stats,)
            if calculate_state_mutual_info:
                out = out + (belief_state_mutual_info,)
            return out
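

    if __name__ == '__main__':
        # Minimal smoke-test sketch (not part of the original file). The config values
        # below are illustrative assumptions, not the defaults shipped with ConvLab.
        from types import SimpleNamespace

        config = SimpleNamespace(hidden_size=768, slot_attention_heads=12, nbt_type='gru',
                                 rnn_zero_init=False, nbt_hidden_size=300, nbt_layers=1,
                                 dropout_rate=0.3, set_similarity=False, set_pooling='cnn',
                                 distance_measure='cosine', loss_function='crossentropy',
                                 predict_actions=False)
        head = SetSUMBTHead(config)
        head.eval()  # disable dropout for the demo

        # Two informable slots, each described by a single pooled embedding
        head.add_slot_candidates({'hotel-name': (torch.randn(768), torch.zeros(1), False),
                                  'hotel-area': (torch.randn(768), torch.zeros(1), False)})
        head.add_value_candidates('hotel-name', torch.randn(5, 768), replace=True)
        head.add_value_candidates('hotel-area', torch.randn(7, 768), replace=True)

        batch_size, dialogue_size, turn_length = 2, 3, 16
        turn_embeddings = torch.randn(batch_size * dialogue_size, turn_length, 768)
        attention_mask = torch.ones(batch_size * dialogue_size, turn_length, 768)

        # No labels are passed, so only the predictive distributions are returned
        belief_probs, _, _, _, hidden_state = head(turn_embeddings, None, attention_mask,
                                                   batch_size, dialogue_size)
        print({slot: tuple(probs.shape) for slot, probs in belief_probs.items()})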