diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py index 588f087bcb18c790acd037b28ef081a28e456d7f..5928df001372156b53414c82504709c3f4eb3f32 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoTUS/emoTUS.py @@ -15,19 +15,23 @@ DEBUG = False class UserActionPolicy(GenTUSUserActionPolicy): - def __init__(self, model_checkpoint, mode="language", only_action=False, max_turn=40, **kwargs): + def __init__(self, model_checkpoint, mode="language", max_turn=40, **kwargs): self.use_sentiment = kwargs.get("use_sentiment", False) self.add_persona = kwargs.get("add_persona", True) self.emotion_mid = kwargs.get("emotion_mid", False) - print("===== model status =====") - print("use_sentiment: ", self.use_sentiment) - print("add_persona: ", self.add_persona) - print("emotion_mid: ", self.emotion_mid) + if not os.path.exists(os.path.dirname(model_checkpoint)): os.makedirs(os.path.dirname(model_checkpoint)) model_downloader(os.path.dirname(model_checkpoint), "https://zenodo.org/record/7801525/files/EmoUS_default.zip") + if mode == "language": + only_action = False + elif mode == "semantic": + only_action = True + else: + raise ValueError("mode should be language or semantic") + super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs) weight = kwargs.get("weight", None) self.kg = KnowledgeGraph( @@ -96,7 +100,10 @@ class UserActionPolicy(GenTUSUserActionPolicy): raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent) output = self._parse_output(raw_output) self.semantic_action = self._remove_illegal_action(output["action"]) - self.utterance = output["text"] + + if not self.only_action: + self.utterance = output["text"] + self.emotion = output["emotion"] if self.use_sentiment: self.sentiment = output["sentiment"] @@ -112,6 +119,9 @@ class UserActionPolicy(GenTUSUserActionPolicy): del inputs + if self.only_action: + return self.semantic_action + return self.utterance def _parse_output(self, in_str): @@ -230,10 +240,14 @@ class UserActionPolicy(GenTUSUserActionPolicy): elif not self.use_sentiment and self.emotion_mid: pos = self._act_emo( pos, model_input, mode, emotion_mode, allow_general_intent) - else: + else: # defalut method pos = self._emo_act( pos, model_input, mode, emotion_mode, allow_general_intent) + if self.only_action: + # return semantic action. Don't need to generate text + return self.vector.decode(self.seq[0, :pos]) + pos = self._update_seq(self.token_map.get_id("start_text"), pos) text = self._get_text(model_input, pos) @@ -393,6 +407,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): class UserPolicy(Policy): def __init__(self, model_checkpoint="convlab/policy/emoTUS/unify/default/EmoUS_default", + mode="language", sample=False, action_penalty=False, **kwargs): @@ -402,12 +417,10 @@ class UserPolicy(Policy): os.makedirs(os.path.dirname(model_checkpoint)) model_downloader(os.path.dirname(model_checkpoint), "https://zenodo.org/record/7801525/files/EmoUS_default.zip") - only_action = False - mode = "language" + self.policy = UserActionPolicy( model_checkpoint, mode=mode, - only_action=only_action, action_penalty=action_penalty, **kwargs) self.policy.load(os.path.join( @@ -446,12 +459,14 @@ if __name__ == "__main__": import os from convlab.dialog_agent import PipelineAgent from convlab.util.custom_util import set_seed + import time use_sentiment, emotion_mid = False, False set_seed(0) # Test semantic level behaviour usr_policy = UserPolicy( # model_checkpoint, # default location = convlab/policy/emoTUS/unify/default/EmoUS_default + mode="semantic", sample=True, use_sentiment=use_sentiment, emotion_mid=emotion_mid) @@ -460,7 +475,7 @@ if __name__ == "__main__": usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user') usr.init_session() print(usr.policy.get_goal()) - + start = time.time() print(usr.response([]), usr.policy.get_emotion()) # print(usr.policy.policy.goal.status) print(usr.response([["inform", "restaurant", "area", "centre"], @@ -469,4 +484,7 @@ if __name__ == "__main__": # print(usr.policy.policy.goal.status) print(usr.response([["request", "restaurant", "price range", "?"]]), usr.policy.get_emotion()) + end = time.time() + print("-"*50) + print("time: ", end - start) # print(usr.policy.policy.goal.status)