From 681e7b80214eb2037a49336bd8521910d3c46296 Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Mon, 23 Jan 2023 15:41:57 +0100
Subject: [PATCH] some additional changes to make crosswoz work with ddpt

---
 convlab/policy/vector/vector_base.py             | 9 ++++++---
 convlab/policy/vtrace_DPT/create_descriptions.py | 2 +-
 convlab/util/multiwoz/lexicalize.py              | 5 ++++-
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py
index 39d378d3..821f7271 100644
--- a/convlab/policy/vector/vector_base.py
+++ b/convlab/policy/vector/vector_base.py
@@ -302,9 +302,12 @@ class VectorBase(Vector):
             entities list:
                 list of entities of the specified domain
         """
-        constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \
-            if domain in self.state else []
-        return self.db.query(domain, constraints, topk=10)
+        #constraints = [[slot, value] for slot, value in self.state[domain].items() if value] \
+        #    if domain in self.state else []
+        state = self.state if domain in self.state else {domain: {}}
+        if domain.lower() == "general":
+            return []
+        return self.db.query(domain, state, topk=10)
 
     def find_nooffer_slot(self, domain):
         """
diff --git a/convlab/policy/vtrace_DPT/create_descriptions.py b/convlab/policy/vtrace_DPT/create_descriptions.py
index fa5f9731..c6e88dab 100644
--- a/convlab/policy/vtrace_DPT/create_descriptions.py
+++ b/convlab/policy/vtrace_DPT/create_descriptions.py
@@ -13,7 +13,7 @@ def create_description_dicts(name='multiwoz21'):
     default_state = ontology['state']
     domains = list(ontology['domains'].keys())
 
-    if name == "multiwoz21":
+    if name == "multiwoz21" or name == "crosswoz":
         db = load_database(name)
         db_domains = db.domains
     else:
diff --git a/convlab/util/multiwoz/lexicalize.py b/convlab/util/multiwoz/lexicalize.py
index 2fe3f299..a8df1067 100755
--- a/convlab/util/multiwoz/lexicalize.py
+++ b/convlab/util/multiwoz/lexicalize.py
@@ -72,7 +72,10 @@ def lexicalize_da(meta, entities, state, requestable):
                         slot_reverse = reverse_da_slot_name_map.get(pair[0], pair[0])
                     else:
                         slot_reverse = reverse_da_slot_name_map['taxi'].get(pair[0], pair[0])
-                    slot_old = REF_SYS_DA[domain.capitalize()].get(slot_reverse, pair[0].lower())
+                    try:
+                        slot_old = REF_SYS_DA[domain.capitalize()].get(slot_reverse, pair[0].lower())
+                    except:
+                        slot_old = ""
                     slot = pair[0]
                     n = int(pair[1]) - 1
                     if len(entities[domain]) > n:
-- 
GitLab