From b5eae17f05c604de4099e42a487648760a1e166f Mon Sep 17 00:00:00 2001 From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com> Date: Tue, 10 Jan 2023 16:54:04 +0100 Subject: [PATCH] Bug fix setsumbt resulting from new torch feature (#117) --- convlab/dst/setsumbt/modeling/setsumbt.py | 2 +- convlab/dst/setsumbt/tracker.py | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py index 0249649f..4b67e35c 100644 --- a/convlab/dst/setsumbt/modeling/setsumbt.py +++ b/convlab/dst/setsumbt/modeling/setsumbt.py @@ -57,7 +57,7 @@ class SlotUtteranceMatching(Module): turn_embeddings = turn_embeddings.transpose(0, 1) key_padding_mask = (attention_mask[:, :, 0] == 0.0) - key_padding_mask[key_padding_mask[:, 0], :] = False + key_padding_mask[torch.clone(key_padding_mask)[:, 0], :] = False hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings, key_padding_mask=key_padding_mask) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index eca7f174..e4033204 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -346,17 +346,26 @@ class SetSUMBTTracker(DST): state_entropy = None # Construct request action prediction - request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] - request_acts = [slot.split('-', 1) for slot in request_acts] - request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + if request_probs is not None: + request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] + request_acts = [slot.split('-', 1) for slot in request_acts] + request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + else: + request_acts = list() # Construct active domain set - active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} + if active_domain_probs is not None: + active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} + else: + active_domains = dict() # Construct general domain action - general_acts = general_act_probs[0, 0, :].argmax(-1).item() - general_acts = [[], ['bye'], ['thank']][general_acts] - general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + if general_act_probs is not None: + general_acts = general_act_probs[0, 0, :].argmax(-1).item() + general_acts = [[], ['bye'], ['thank']][general_acts] + general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + else: + general_acts = list() user_acts = request_acts + general_acts -- GitLab