From 4c3d02fcb5f3693768464684c52b230507e473c6 Mon Sep 17 00:00:00 2001
From: Christian Geishauser <45534723+ChrisGeishauser@users.noreply.github.com>
Date: Thu, 15 Dec 2022 11:12:36 +0100
Subject: [PATCH] =?UTF-8?q?changed=20masking=20such=20that=20it=20works=20?=
 =?UTF-8?q?also=20if=20belief=20state=20has=20less=20keys=E2=80=A6=20(#105?=
 =?UTF-8?q?)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* changed masking such that it works also if belief state has less keys then default state

* Update vector_base.py

Co-authored-by: Christian <christian.geishauser@hhu.de>
Co-authored-by: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com>
---
 convlab/policy/vector/vector_base.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py
index 8f72144c..d85cca10 100644
--- a/convlab/policy/vector/vector_base.py
+++ b/convlab/policy/vector/vector_base.py
@@ -259,11 +259,12 @@ class VectorBase(Vector):
             if intent in ['nobook', 'nooffer'] and slot != 'none':
                 mask_list[i] = 1.0
 
-            if "book" in slot and intent == 'inform' and not self.state[domain][slot]:
-                mask_list[i] = 1.0
+            if "book" in slot and intent == 'inform':
+                if not self.state.get(domain, {}).get(slot, {}):
+                    mask_list[i] = 1.0
 
             if domain == 'taxi':
-                if slot in self.state['taxi']:
+                if slot in self.state.get('taxi', {}):
                     if not self.state['taxi'][slot] and intent == 'inform':
                         mask_list[i] = 1.0
 
@@ -315,7 +316,7 @@ class VectorBase(Vector):
             entities list:
                 list of entities of the specified domain
         """
-        constraints = self.state[domain]
+        constraints = self.state.get(domain, {})
 
         # Leave slots out of constraints to find which slot constraint results in no entities being found
         for constraint_slot in constraints:
-- 
GitLab