Skip to content
Snippets Groups Projects
Commit 6923cac0 authored by Carel van Niekerk's avatar Carel van Niekerk :desktop:
Browse files

Merge branch 'setsumbt_bug' into 'github_master'

Fix SetSUMBT bug resulting from new torch feature

See merge request dsml/convlab/ConvLab3!59
parents 317dc9b8 031a1a0b
No related branches found
No related tags found
No related merge requests found
......@@ -57,7 +57,7 @@ class SlotUtteranceMatching(Module):
turn_embeddings = turn_embeddings.transpose(0, 1)
key_padding_mask = (attention_mask[:, :, 0] == 0.0)
key_padding_mask[key_padding_mask[:, 0], :] = False
key_padding_mask[torch.clone(key_padding_mask)[:, 0], :] = False
hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings,
key_padding_mask=key_padding_mask)
......
......@@ -346,17 +346,26 @@ class SetSUMBTTracker(DST):
state_entropy = None
# Construct request action prediction
if request_probs is not None:
request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
request_acts = [slot.split('-', 1) for slot in request_acts]
request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
else:
request_acts = list()
# Construct active domain set
if active_domain_probs is not None:
active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
else:
active_domains = dict()
# Construct general domain action
if general_act_probs is not None:
general_acts = general_act_probs[0, 0, :].argmax(-1).item()
general_acts = [[], ['bye'], ['thank']][general_acts]
general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
else:
general_acts = list()
user_acts = request_acts + general_acts
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment