diff --git a/convlab/dst/setsumbt/modeling/setsumbt.py b/convlab/dst/setsumbt/modeling/setsumbt.py index 0249649f0840d66b0cec8a65c91aded906f62f85..4b67e35c3e8d2e0eaa5abcff2376d89ff67ae3cc 100644 --- a/convlab/dst/setsumbt/modeling/setsumbt.py +++ b/convlab/dst/setsumbt/modeling/setsumbt.py @@ -57,7 +57,7 @@ class SlotUtteranceMatching(Module): turn_embeddings = turn_embeddings.transpose(0, 1) key_padding_mask = (attention_mask[:, :, 0] == 0.0) - key_padding_mask[key_padding_mask[:, 0], :] = False + key_padding_mask[torch.clone(key_padding_mask)[:, 0], :] = False hidden, _ = self.attention(query=slot_embeddings, key=turn_embeddings, value=turn_embeddings, key_padding_mask=key_padding_mask) diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index eca7f1749369f9569d6b923312a93cd317e0701c..e40332048188b7f5a1f43e397896a8b6b201553d 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -346,17 +346,26 @@ class SetSUMBTTracker(DST): state_entropy = None # Construct request action prediction - request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] - request_acts = [slot.split('-', 1) for slot in request_acts] - request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + if request_probs is not None: + request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] + request_acts = [slot.split('-', 1) for slot in request_acts] + request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + else: + request_acts = list() # Construct active domain set - active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} + if active_domain_probs is not None: + active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} + else: + active_domains = dict() # Construct general domain action - general_acts = general_act_probs[0, 0, :].argmax(-1).item() - general_acts = [[], ['bye'], ['thank']][general_acts] - general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + if general_act_probs is not None: + general_acts = general_act_probs[0, 0, :].argmax(-1).item() + general_acts = [[], ['bye'], ['thank']][general_acts] + general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + else: + general_acts = list() user_acts = request_acts + general_acts