Commit 96bc68bc authored by Carel van Niekerk

Fix bugs in training code

parent cf2d6ad7
@@ -34,9 +34,9 @@ class BertSetSUMBT(BertPreTrainedModel):
             for p in self.bert.parameters():
                 p.requires_grad = False
 
-        self.prediction_head = SetSUMBTHead(config)
-        self.add_slot_candidates = self.prediction_head.add_slot_candidates
-        self.add_value_candidates = self.prediction_head.add_value_candidates
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
 
     def forward(self,
                 input_ids: torch.Tensor,
@@ -80,10 +80,10 @@ class BertSetSUMBT(BertPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
 
         if get_turn_pooled_representation:
-            return self.prediction_head(turn_embeddings, bert_output.pooler_output, attention_mask,
+            return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask,
                                         batch_size, dialogue_size, hidden_state, state_labels,
                                         request_labels, active_domain_labels, general_act_labels,
                                         calculate_state_mutual_info) + (bert_output.pooler_output,)
-        return self.prediction_head(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
+        return self.setsumbt(turn_embeddings, bert_output.pooler_output, attention_mask, batch_size,
                                     dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
                                     general_act_labels, calculate_state_mutual_info)
@@ -38,9 +38,9 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
             for p in self.roberta.parameters():
                 p.requires_grad = False
 
-        self.prediction_head = SetSUMBTHead(config)
-        self.add_slot_candidates = self.prediction_head.add_slot_candidates
-        self.add_value_candidates = self.prediction_head.add_value_candidates
+        self.setsumbt = SetSUMBTHead(config)
+        self.add_slot_candidates = self.setsumbt.add_slot_candidates
+        self.add_value_candidates = self.setsumbt.add_value_candidates
 
     def forward(self,
                 input_ids: torch.Tensor,
@@ -86,10 +86,10 @@ class RobertaSetSUMBT(RobertaPreTrainedModel):
         turn_embeddings = turn_embeddings.reshape(batch_size * dialogue_size, turn_size, -1)
 
         if get_turn_pooled_representation:
-            return self.prediction_head(turn_embeddings, roberta_output.pooler_output, attention_mask,
+            return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask,
                                         batch_size, dialogue_size, hidden_state, state_labels,
                                         request_labels, active_domain_labels, general_act_labels,
                                         calculate_state_mutual_info) + (roberta_output.pooler_output,)
-        return self.prediction_head(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
+        return self.setsumbt(turn_embeddings, roberta_output.pooler_output, attention_mask, batch_size,
                                     dialogue_size, hidden_state, state_labels, request_labels, active_domain_labels,
                                     general_act_labels, calculate_state_mutual_info)
@@ -59,9 +59,7 @@ class SlotUtteranceMatching(Module):
         key_padding_mask = (attention_mask[:, :, 0] == 0.0)
         key_padding_mask[key_padding_mask[:, 0], :] = False
 
-        hidden, _ = self.slot_attention(query=slot_embeddings,
-                                        key=turn_embeddings,
-                                        value=turn_embeddings,
+        hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
                                    key_padding_mask=key_padding_mask)
 
         attention_mask = attention_mask[:, 0, :].unsqueeze(0).repeat((slot_embeddings.size(0), 1, 1))
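For context on this hunk: `SlotUtteranceMatching` attends from the slot description embeddings over the token embeddings of the current turn, and the fix points the call at the attention module the layer actually defines. Below is a minimal standalone sketch of that pattern, assuming the attention layer is a standard `torch.nn.MultiheadAttention` and using invented tensor sizes; it is an illustration, not the repository's code.

```python
import torch
from torch import nn

# Sketch only: slot-utterance matching with nn.MultiheadAttention,
# where padded turn tokens are excluded from the keys.
hidden_size, heads = 768, 12                        # assumed sizes
attention = nn.MultiheadAttention(hidden_size, heads)

slot_embeddings = torch.randn(9, 4, hidden_size)    # [slot_desc_len, batch, hidden]
turn_embeddings = torch.randn(64, 4, hidden_size)   # [turn_len, batch, hidden]
attention_mask = torch.ones(4, 64)                  # 1.0 = real token, 0.0 = padding

key_padding_mask = (attention_mask == 0.0)          # True entries are ignored as keys
hidden, _ = attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
                      key_padding_mask=key_padding_mask)
print(hidden.shape)                                 # torch.Size([9, 4, 768])
```

The guard in the hunk above (`key_padding_mask[key_padding_mask[:, 0], :] = False`) additionally prevents fully padded turns from masking out every key, which would otherwise produce NaN attention weights.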
@@ -203,6 +201,7 @@ class SetSUMBTHead(Module):
             config (configuration): Model configuration class
         """
         super(SetSUMBTHead, self).__init__()
+        self.config = config
 
         # Slot Utterance matching attention
         self.slot_utterance_matching = SlotUtteranceMatching(config.hidden_size, config.slot_attention_heads)
@@ -414,7 +413,7 @@ class SetSUMBTHead(Module):
                                                      dialogue_size, -1).transpose(1, 2)
         if self.config.set_similarity:
             belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1,
-                                                        self.config.nbt_hidden_size)
+                                                        self.config.hidden_size)
         # [batch_size, dialogue_size, num_slots, *slot_desc_len, 768]
 
         # Pooling of the set of latent context representation
@@ -427,6 +426,9 @@ class SetSUMBTHead(Module):
             belief_embedding = belief_embedding.reshape(batch_size, dialogue_size, num_slots, -1)
 
         # Perform classification
+        # Get padded batch, dialogue idx pairs
+        batches, dialogues = torch.where(attention_mask[:, 0, 0].reshape(batch_size, dialogue_size) == 0.0)
+
         if self.config.predict_actions:
             # User request prediction
             request_probs = dict()
@@ -435,8 +437,6 @@ class SetSUMBTHead(Module):
 
                 # Store output probabilities
                 request_logits = request_logits.reshape(batch_size, dialogue_size)
-                mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
-                batches, dialogues = torch.where(mask == 0.0)
                 # Set request scores to 0.0 for padded turns
                 request_logits[batches, dialogues] = 0.0
                 request_probs[slot] = torch.sigmoid(request_logits)
@@ -466,8 +466,6 @@ class SetSUMBTHead(Module):
 
                 # Store output probabilities
                 active_domain_logits = active_domain_logits.reshape(batch_size, dialogue_size)
-                mask = attention_mask[0, :, 0].reshape(batch_size, dialogue_size)
-                batches, dialogues = torch.where(mask == 0.0)
                 active_domain_logits[batches, dialogues] = 0.0
                 active_domain_probs[domain] = torch.sigmoid(active_domain_logits)
@@ -524,8 +522,6 @@ class SetSUMBTHead(Module):
                     belief_state_mutual_info[slot] = self.loss.logits_to_mutual_info(logits).reshape(batch_size, dialogue_size)
 
                 # Set padded turn probabilities to zero
-                mask = attention_mask[self.slot_ids[slot], :, 0].reshape(batch_size, dialogue_size)
-                batches, dialogues = torch.where(mask == 0.0)
                 probs_[batches, dialogues, :] = 0.0
                 belief_state_probs[slot] = probs_
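The last four hunks all apply one refactor: the `(batch, dialogue)` index pairs of padded turns are now computed once from `attention_mask[:, 0, 0]` with `torch.where` and reused wherever padded-turn scores are zeroed, replacing the per-branch `mask = attention_mask[0, :, 0]` recomputation. A standalone sketch of that pattern with invented shapes and tensors, not code from the repository:

```python
import torch

# Sketch only: zero out scores for padded dialogue turns using
# index pairs computed once from a turn-level mask.
batch_size, dialogue_size, num_values = 2, 3, 5

# 0.0 marks a padded dialogue turn, 1.0 a real one.
turn_mask = torch.tensor([[1.0, 1.0, 0.0],
                          [1.0, 0.0, 0.0]])

# Index pairs of all padded turns, computed once and reused.
batches, dialogues = torch.where(turn_mask == 0.0)

request_logits = torch.randn(batch_size, dialogue_size)
belief_probs = torch.softmax(torch.randn(batch_size, dialogue_size, num_values), dim=-1)

request_logits[batches, dialogues] = 0.0     # padded turns contribute no request score
belief_probs[batches, dialogues, :] = 0.0    # and no belief-state probability mass

print(batches.tolist(), dialogues.tolist())  # [0, 1, 1] [2, 1, 2]
```

Computing the indices once also keeps the masking consistent across the request, active-domain, and belief-state branches, since all three now zero exactly the same turns.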