From acaae59a95c5c11a89861352f6e3f0dba38c12e4 Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Thu, 13 Apr 2023 13:33:44 +0200 Subject: [PATCH] fix book slot issue --- convlab/policy/emoTUS/emoTUS.py | 4 +++- convlab/policy/emoTUS/token_map.py | 3 ++- convlab/policy/genTUS/stepGenTUS.py | 11 +++++++++++ convlab/policy/genTUS/token_map.py | 7 ++++--- convlab/policy/genTUS/unify/knowledge_graph.py | 13 ++++++++++++- 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py index 5928df00..ad029dd9 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoTUS/emoTUS.py @@ -47,7 +47,6 @@ class UserActionPolicy(GenTUSUserActionPolicy): self.init_session() def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None): - # TODO emotion allow_general_intent = False self.model.eval() @@ -336,6 +335,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): return self.kg.get_sentiment(next_token_logits, mode) def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal", sentiment=None): + mode = "max" # emotion is always max next_token_logits = self.model.get_next_token_logits( model_input, generated_so_far) return self.kg.get_emotion(next_token_logits, mode, emotion_mode, sentiment) @@ -413,6 +413,8 @@ class UserPolicy(Policy): **kwargs): # self.config = config print("emoTUS model checkpoint: ", model_checkpoint) + if sample: + print("EmoUS will sample action, but emotion is always max") if not os.path.exists(os.path.dirname(model_checkpoint)): os.makedirs(os.path.dirname(model_checkpoint)) model_downloader(os.path.dirname(model_checkpoint), diff --git a/convlab/policy/emoTUS/token_map.py b/convlab/policy/emoTUS/token_map.py index 407e3102..1c8eef2f 100644 --- a/convlab/policy/emoTUS/token_map.py +++ b/convlab/policy/emoTUS/token_map.py @@ -20,7 +20,8 @@ class tokenMap: 'end_act': '"]], "', 'start_text': 'text": "', 'end_json': '}', - 'end_json_2': '"}' + 'end_json_2': '"}', + 'book': 'book' } if only_action: diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py index c0ff690f..4ffe3571 100644 --- a/convlab/policy/genTUS/stepGenTUS.py +++ b/convlab/policy/genTUS/stepGenTUS.py @@ -223,6 +223,11 @@ class UserActionPolicy(Policy): # get slot slot = self._get_slot( model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) + if "book" in slot["token_name"]: + pos = self._update_seq(self.token_map.get_id('book'), pos) + slot = self._get_book_slot( + model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) + slot["token_name"] = "book" + slot["token_name"] pos = self._update_seq(slot["token_id"], pos) pos = self._update_seq(self.token_map.get_id('sep_token'), pos) @@ -252,6 +257,12 @@ class UserActionPolicy(Policy): is_mentioned = self.vector.is_mentioned(domain) return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned) + def _get_book_slot(self, model_input, generated_so_far, intent, domain, mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + is_mentioned = self.vector.is_mentioned(domain) + return self.kg.get_book_slot(next_token_logits, intent, domain, mode, is_mentioned) + def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"): next_token_logits = self.model.get_next_token_logits( model_input, generated_so_far) diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py index 7825c288..a6187318 100644 --- a/convlab/policy/genTUS/token_map.py +++ b/convlab/policy/genTUS/token_map.py @@ -14,11 +14,12 @@ class tokenMap: 'start_json': '{"action": [', # 49643, 10845, 7862, 646 'start_act': '["', # 49329 'sep_token': '", "', # 1297('",'), 22 - 'sep_act': '"], ["', # 49177 + 'sep_act': '"], ["', # 49177 'end_act': '"]], "', # 42248, 7479, 22 'start_text': 'text": "', # 29015, 7862, 22 - 'end_json': '}', # 24303 - 'end_json_2': '"}' # 48805 + 'end_json': '}', # 24303 + 'end_json_2': '"}', # 48805 + 'book': 'book' # 6298 } if only_action: self.format_tokens['end_act'] = '"]]}' diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py index 1c59a704..81b30442 100644 --- a/convlab/policy/genTUS/unify/knowledge_graph.py +++ b/convlab/policy/genTUS/unify/knowledge_graph.py @@ -83,7 +83,7 @@ class KnowledgeGraph: if slot not in self.user_goal[domain]: self.user_goal[domain][slot] = [] - self.add_token(domain, "slot") + self.add_token(slot, "slot") if value not in self.user_goal[domain][slot]: value = json.dumps(str(value))[1:-1] @@ -204,6 +204,17 @@ class KnowledgeGraph: return token_map + def get_book_slot(self, outputs, intent, domain, mode="max", is_mentioned=False): + slot_list = self.candidate( + candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned) + book_slot_list = [s.replace("book", "") + for s in slot_list if 'book' in s] + + token_map = self._get_max_domain_token( + outputs=outputs, candidates=book_slot_list, map_type="slot", mode=mode) + + return token_map + def get_value(self, outputs, intent, domain, slot, mode="max"): if intent in self.general_intent or slot.lower() == "none": token_name = "none" -- GitLab