Skip to content
Snippets Groups Projects
Commit 9c68a78b authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

Refactor add_name and remove cur_domain from vectoriser

parent efc44852
No related branches found
No related tags found
No related merge requests found
......@@ -43,7 +43,6 @@ class VectorBase(Vector):
self.character = character
self.name_history_flag = True
self.name_action_prev = []
self.cur_domain = None
self.requestable = ['request']
self.informable = ['inform', 'recommend']
......@@ -254,12 +253,12 @@ class VectorBase(Vector):
if intent in ['nobook', 'nooffer'] and slot != 'none':
mask_list[i] = 1.0
if "book" in slot and intent.lower() == 'inform' and not self.state[domain][slot]:
if "book" in slot and intent == 'inform' and not self.state[domain][slot]:
mask_list[i] = 1.0
if domain == 'taxi':
if slot in self.state['taxi']:
if not self.state['taxi'][slot] and intent.lower() == 'inform':
if not self.state['taxi'][slot] and intent == 'inform':
mask_list[i] = 1.0
return mask_list
......@@ -372,10 +371,8 @@ class VectorBase(Vector):
entities = {}
for domint in action:
domain, intent = domint.split('-')
if domain not in entities and domain.lower() not in ['general']:
if domain not in entities and domain not in ['general']:
entities[domain] = self.dbquery_domain(domain)
if self.cur_domain and self.cur_domain not in entities:
entities[self.cur_domain] = self.dbquery_domain(self.cur_domain)
# From db query find which slot causes no_offer
nooffer = [domint for domint in action if 'nooffer' in domint]
......@@ -418,28 +415,33 @@ class VectorBase(Vector):
name_inform = []
contains_name = False
# General Inform Condition for Naming
cur_inform = str(self.cur_domain) + '-inform'
cur_request = str(self.cur_domain) + '-request'
domain = [domint.split('-', 1)[0] for domint in action]
domain = [d for d in domain if d not in ['general']]
domain = domain[0] if domain else 'none'
if domain == 'none':
raise NameError('Domain not defined')
cur_inform = domain + '-inform'
cur_request = domain + '-request'
index = -1
if cur_inform in action:
for [item, idx] in action[cur_inform]:
if item == 'name':
for [slot, value_id] in action[cur_inform]:
if slot == 'name':
contains_name = True
elif self.cur_domain == 'train' and item == 'id':
elif domain == 'train' and slot == 'id':
contains_name = True
elif self.cur_domain == 'hospital':
elif domain == 'hospital':
contains_name = True
elif item == 'choice' and cur_request in action:
elif slot == 'choice' and cur_request in action:
contains_name = True
if index != -1 and index != idx and idx is not None:
if index != -1 and index != value_id and value_id is not None:
logging.debug(
"System is likely refering multiple entities within this turn")
index = idx
index = value_id
if contains_name == False:
if self.cur_domain == 'train':
if domain == 'train':
name_act = ['id', index]
else:
name_act = ['name', index]
......@@ -465,7 +467,7 @@ class VectorBase(Vector):
pointer_vector = np.zeros(6 * len(self.db_domains))
number_entities_dict = {}
for domain in self.db_domains:
entities = self.dbquery_domain(domain.lower())
entities = self.dbquery_domain(domain)
number_entities_dict[domain] = len(entities)
pointer_vector = self.one_hot_vector(
len(entities), domain, pointer_vector)
......
......@@ -43,8 +43,6 @@ class VectorBinary(VectorBase):
action = state['user_action']
for intent, domain, slot, value in action:
domain_active_dict[domain] = True
if domain in self.db_domains:
self.cur_domain = domain
opp_act_vec = self.vectorize_user_act(state)
last_act_vec = self.vectorize_system_act(state)
......
......@@ -145,8 +145,6 @@ class MultiWozVector(VectorBase):
action = state['user_action']
for intent, domain, slot, value in action:
domain_active_dict[domain] = True
if domain in self.db_domains:
self.cur_domain = domain
action = state['user_action'] if self.character == 'sys' else state['system_action']
opp_action = delexicalize_da(action, self.requestable)
......@@ -189,8 +187,6 @@ class MultiWozVector(VectorBase):
if 'active_domains' in state:
domain_active = state['active_domains'][domain.lower()]
domain_active_dict[domain] = domain_active
if domain in self.db_domains and domain_active:
self.cur_domain = domain
else:
if [slot for slot, value in state['belief_state'][domain.lower()]['semi'].items() if value]:
domain_active_dict[domain] = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment