diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py index 5928df001372156b53414c82504709c3f4eb3f32..ad029dd9cb1e9f9bcc94cde7754073cfe15d5977 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoTUS/emoTUS.py @@ -47,7 +47,6 @@ class UserActionPolicy(GenTUSUserActionPolicy): self.init_session() def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None): - # TODO emotion allow_general_intent = False self.model.eval() @@ -336,6 +335,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): return self.kg.get_sentiment(next_token_logits, mode) def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal", sentiment=None): + mode = "max" # emotion is always max next_token_logits = self.model.get_next_token_logits( model_input, generated_so_far) return self.kg.get_emotion(next_token_logits, mode, emotion_mode, sentiment) @@ -413,6 +413,8 @@ class UserPolicy(Policy): **kwargs): # self.config = config print("emoTUS model checkpoint: ", model_checkpoint) + if sample: + print("EmoUS will sample action, but emotion is always max") if not os.path.exists(os.path.dirname(model_checkpoint)): os.makedirs(os.path.dirname(model_checkpoint)) model_downloader(os.path.dirname(model_checkpoint), diff --git a/convlab/policy/emoTUS/token_map.py b/convlab/policy/emoTUS/token_map.py index 407e3102cdeda2690461fa28805aef687868febd..1c8eef2fca42dbc36fc80e9e4c6355ccd94a0091 100644 --- a/convlab/policy/emoTUS/token_map.py +++ b/convlab/policy/emoTUS/token_map.py @@ -20,7 +20,8 @@ class tokenMap: 'end_act': '"]], "', 'start_text': 'text": "', 'end_json': '}', - 'end_json_2': '"}' + 'end_json_2': '"}', + 'book': 'book' } if only_action: diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py index c0ff690f01f52422fd64683964e46a18a51d0dcf..4ffe3571eee99c1e8bbe685e623db9c75124894c 100644 --- a/convlab/policy/genTUS/stepGenTUS.py +++ b/convlab/policy/genTUS/stepGenTUS.py @@ -223,6 +223,11 @@ class UserActionPolicy(Policy): # get slot slot = self._get_slot( model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) + if "book" in slot["token_name"]: + pos = self._update_seq(self.token_map.get_id('book'), pos) + slot = self._get_book_slot( + model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode) + slot["token_name"] = "book" + slot["token_name"] pos = self._update_seq(slot["token_id"], pos) pos = self._update_seq(self.token_map.get_id('sep_token'), pos) @@ -252,6 +257,12 @@ class UserActionPolicy(Policy): is_mentioned = self.vector.is_mentioned(domain) return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned) + def _get_book_slot(self, model_input, generated_so_far, intent, domain, mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + is_mentioned = self.vector.is_mentioned(domain) + return self.kg.get_book_slot(next_token_logits, intent, domain, mode, is_mentioned) + def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"): next_token_logits = self.model.get_next_token_logits( model_input, generated_so_far) diff --git a/convlab/policy/genTUS/token_map.py b/convlab/policy/genTUS/token_map.py index 7825c2880928c40f68284b0c3199932cd1cfc477..a6187318dc0ba37eea0318c5deebe4874f691fbf 100644 --- a/convlab/policy/genTUS/token_map.py +++ b/convlab/policy/genTUS/token_map.py @@ -14,11 +14,12 @@ class tokenMap: 'start_json': '{"action": [', # 49643, 10845, 7862, 646 'start_act': '["', # 49329 'sep_token': '", "', # 1297('",'), 22 - 'sep_act': '"], ["', # 49177 + 'sep_act': '"], ["', # 49177 'end_act': '"]], "', # 42248, 7479, 22 'start_text': 'text": "', # 29015, 7862, 22 - 'end_json': '}', # 24303 - 'end_json_2': '"}' # 48805 + 'end_json': '}', # 24303 + 'end_json_2': '"}', # 48805 + 'book': 'book' # 6298 } if only_action: self.format_tokens['end_act'] = '"]]}' diff --git a/convlab/policy/genTUS/unify/knowledge_graph.py b/convlab/policy/genTUS/unify/knowledge_graph.py index 1c59a7042e81880fd20bdbd84c6ea3acce96498f..81b30442773940d2f3b19529e9cc53f81a66d8ed 100644 --- a/convlab/policy/genTUS/unify/knowledge_graph.py +++ b/convlab/policy/genTUS/unify/knowledge_graph.py @@ -83,7 +83,7 @@ class KnowledgeGraph: if slot not in self.user_goal[domain]: self.user_goal[domain][slot] = [] - self.add_token(domain, "slot") + self.add_token(slot, "slot") if value not in self.user_goal[domain][slot]: value = json.dumps(str(value))[1:-1] @@ -204,6 +204,17 @@ class KnowledgeGraph: return token_map + def get_book_slot(self, outputs, intent, domain, mode="max", is_mentioned=False): + slot_list = self.candidate( + candidate_type="slot", intent=intent, domain=domain, is_mentioned=is_mentioned) + book_slot_list = [s.replace("book", "") + for s in slot_list if 'book' in s] + + token_map = self._get_max_domain_token( + outputs=outputs, candidates=book_slot_list, map_type="slot", mode=mode) + + return token_map + def get_value(self, outputs, intent, domain, slot, mode="max"): if intent in self.general_intent or slot.lower() == "none": token_name = "none"