Skip to content
Snippets Groups Projects
Unverified Commit b5eae17f authored by Carel van Niekerk's avatar Carel van Niekerk Committed by GitHub
Browse files

Bug fix setsumbt resulting from new torch feature (#117)

parent 57477760
Branches
No related tags found
Loading
......@@ -57,7 +57,7 @@ class SlotUtteranceMatching(Module):
turn_embeddings = turn_embeddings.transpose(0, 1)
key_padding_mask = (attention_mask[:, :, 0] == 0.0)
key_padding_mask[key_padding_mask[:, 0], :] = False
key_padding_mask[torch.clone(key_padding_mask)[:, 0], :] = False
hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
key_padding_mask=key_padding_mask)
......
......@@ -346,17 +346,26 @@ class SetSUMBTTracker(DST):
state_entropy = None
# Construct request action prediction
if request_probs is not None:
request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
request_acts = [slot.split('-', 1) for slot in request_acts]
request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
else:
request_acts = list()
# Construct active domain set
if active_domain_probs is not None:
active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
else:
active_domains = dict()
# Construct general domain action
if general_act_probs is not None:
general_acts = general_act_probs[0, 0, :].argmax(-1).item()
general_acts = [[], ['bye'], ['thank']][general_acts]
general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
else:
general_acts = list()
user_acts = request_acts + general_acts
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment