From 2d0b95853d306a06ba12c2d913862554291e0d75 Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Thu, 13 Apr 2023 12:00:41 +0200 Subject: [PATCH] small bugfix when using sgd for mle since sgd has no booked information and not the full belief state is given in the dataset --- convlab/policy/mle/loader.py | 3 ++- convlab/policy/vector/vector_binary.py | 26 +++++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/convlab/policy/mle/loader.py b/convlab/policy/mle/loader.py index 9c10d2a7..5af5ddf3 100755 --- a/convlab/policy/mle/loader.py +++ b/convlab/policy/mle/loader.py @@ -105,7 +105,8 @@ class PolicyDataVectorizer: state['terminated'] = data_point['terminated'] if self.dst is not None and state['terminated']: self.dst.init_session() - state['booked'] = data_point['booked'] + if "booked" in data_point: + state['booked'] = data_point['booked'] dialogue_act = flatten_acts(data_point['dialogue_acts']) vectorized_state, mask = self.vector.state_vectorize(state) diff --git a/convlab/policy/vector/vector_binary.py b/convlab/policy/vector/vector_binary.py index c6b02a11..e444dec2 100755 --- a/convlab/policy/vector/vector_binary.py +++ b/convlab/policy/vector/vector_binary.py @@ -47,8 +47,15 @@ class VectorBinary(VectorBase): opp_act_vec = self.vectorize_user_act(state) last_act_vec = self.vectorize_system_act(state) belief_state, domain_active_dict = self.vectorize_belief_state(state, domain_active_dict) - book = self.vectorize_booked(state) - degree, number_entities_dict = self.pointer() + if "booked" in state: + book = self.vectorize_booked(state) + else: + book = [] + if self.db is not None: + degree, number_entities_dict = self.pointer() + else: + degree = [] + number_entities_dict = {} final = 1. if state['terminated'] else 0. state_vec = np.r_[opp_act_vec, last_act_vec, @@ -82,13 +89,14 @@ class VectorBinary(VectorBase): belief_state = np.zeros(self.belief_state_dim) i = 0 for domain in self.belief_domains: - for slot, value in state['belief_state'][domain].items(): - if value: - belief_state[i] = 1. - i += 1 - - if [slot for slot, value in state['belief_state'][domain].items() if value]: - domain_active_dict[domain] = True + if domain in state['belief_state']: + for slot, value in state['belief_state'][domain].items(): + if value: + belief_state[i] = 1. + i += 1 + + if [slot for slot, value in state['belief_state'][domain].items() if value]: + domain_active_dict[domain] = True return belief_state, domain_active_dict def vectorize_system_act(self, state): -- GitLab