diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 8f72144ce37a970fe4855a19d5bc8002fc2b4034..a9e096e453ed0cffaba2549ec3cdbf352a08caaa 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -259,11 +259,12 @@ class VectorBase(Vector): if intent in ['nobook', 'nooffer'] and slot != 'none': mask_list[i] = 1.0 - if "book" in slot and intent == 'inform' and not self.state[domain][slot]: - mask_list[i] = 1.0 + if "book" in slot and intent == 'inform': + if not self.state.get(domain, {}).get(slot, {}): + mask_list[i] = 1.0 if domain == 'taxi': - if slot in self.state['taxi']: + if slot in self.state.get('taxi', {}): if not self.state['taxi'][slot] and intent == 'inform': mask_list[i] = 1.0