From 4c3d02fcb5f3693768464684c52b230507e473c6 Mon Sep 17 00:00:00 2001 From: Christian Geishauser <45534723+ChrisGeishauser@users.noreply.github.com> Date: Thu, 15 Dec 2022 11:12:36 +0100 Subject: [PATCH] =?UTF-8?q?changed=20masking=20such=20that=20it=20works=20?= =?UTF-8?q?also=20if=20belief=20state=20has=20less=20keys=E2=80=A6=20(#105?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * changed masking such that it works also if belief state has less keys then default state * Update vector_base.py Co-authored-by: Christian <christian.geishauser@hhu.de> Co-authored-by: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com> --- convlab/policy/vector/vector_base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 8f72144c..d85cca10 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -259,11 +259,12 @@ class VectorBase(Vector): if intent in ['nobook', 'nooffer'] and slot != 'none': mask_list[i] = 1.0 - if "book" in slot and intent == 'inform' and not self.state[domain][slot]: - mask_list[i] = 1.0 + if "book" in slot and intent == 'inform': + if not self.state.get(domain, {}).get(slot, {}): + mask_list[i] = 1.0 if domain == 'taxi': - if slot in self.state['taxi']: + if slot in self.state.get('taxi', {}): if not self.state['taxi'][slot] and intent == 'inform': mask_list[i] = 1.0 @@ -315,7 +316,7 @@ class VectorBase(Vector): entities list: list of entities of the specified domain """ - constraints = self.state[domain] + constraints = self.state.get(domain, {}) # Leave slots out of constraints to find which slot constraint results in no entities being found for constraint_slot in constraints: -- GitLab