Skip to content
Snippets Groups Projects
Commit 681e7b80 authored by Christian's avatar Christian
Browse files

some additional changes to make crosswoz work with ddpt

parent d18cf121
No related branches found
No related tags found
No related merge requests found
......@@ -302,9 +302,12 @@ class VectorBase(Vector):
entities list:
list of entities of the specified domain
"""
constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \
if domain in self.state else []
return self.db.query(domain, constraints, topk=10)
#constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \
# if domain in self.state else []
state = self.state if domain in self.state else {domain: {}}
if domain.lower() == "general":
return []
return self.db.query(domain, state, topk=10)
def find_nooffer_slot(self, domain):
"""
......
......@@ -13,7 +13,7 @@ def create_description_dicts(name='multiwoz21'):
default_state = ontology['state']
domains = list(ontology['domains'].keys())
if name == "multiwoz21":
if name == "multiwoz21" or name == "crosswoz":
db = load_database(name)
db_domains = db.domains
else:
......
......@@ -72,7 +72,10 @@ def lexicalize_da(meta, entities, state, requestable):
slot_reverse = reverse_da_slot_name_map.get(pair[0], pair[0])
else:
slot_reverse = reverse_da_slot_name_map['taxi'].get(pair[0], pair[0])
try:
slot_old = REF_SYS_DA[domain.capitalize()].get(slot_reverse, pair[0].lower())
except:
slot_old = ""
slot = pair[0]
n = int(pair[1]) - 1
if len(entities[domain]) > n:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment