From 681e7b80214eb2037a49336bd8521910d3c46296 Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Mon, 23 Jan 2023 15:41:57 +0100 Subject: [PATCH] some additional changes to make crosswoz work with ddpt --- convlab/policy/vector/vector_base.py | 9 ++++++--- convlab/policy/vtrace_DPT/create_descriptions.py | 2 +- convlab/util/multiwoz/lexicalize.py | 5 ++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 39d378d3..821f7271 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -302,9 +302,12 @@ class VectorBase(Vector): entities list: list of entities of the specified domain """ - constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \ - if domain in self.state else [] - return self.db.query(domain, constraints, topk=10) + #constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \ + # if domain in self.state else [] + state = self.state if domain in self.state else {domain: {}} + if domain.lower() == "general": + return [] + return self.db.query(domain, state, topk=10) def find_nooffer_slot(self, domain): """ diff --git a/convlab/policy/vtrace_DPT/create_descriptions.py b/convlab/policy/vtrace_DPT/create_descriptions.py index fa5f9731..c6e88dab 100644 --- a/convlab/policy/vtrace_DPT/create_descriptions.py +++ b/convlab/policy/vtrace_DPT/create_descriptions.py @@ -13,7 +13,7 @@ def create_description_dicts(name='multiwoz21'): default_state = ontology['state'] domains = list(ontology['domains'].keys()) - if name == "multiwoz21": + if name == "multiwoz21" or name == "crosswoz": db = load_database(name) db_domains = db.domains else: diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py index 2fe3f299..a8df1067 100755 --- a/convlab/util/multiwoz/lexicalize.py +++ b/convlab/util/multiwoz/lexicalize.py @@ -72,7 +72,10 @@ def lexicalize_da(meta, entities, state, requestable): slot_reverse = reverse_da_slot_name_map.get(pair[0], pair[0]) else: slot_reverse = reverse_da_slot_name_map['taxi'].get(pair[0], pair[0]) - slot_old = REF_SYS_DA[domain.capitalize()].get(slot_reverse, pair[0].lower()) + try: + slot_old = REF_SYS_DA[domain.capitalize()].get(slot_reverse, pair[0].lower()) + except: + slot_old = "" slot = pair[0] n = int(pair[1]) - 1 if len(entities[domain]) > n: -- GitLab