Skip to content
Snippets Groups Projects
Select Git revision
  • ac5532846030325797480896124f2a3e4bd47dd6
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

emoTUS.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    emoTUS.py 12.68 KiB
    import os
    import json
    
    import torch
    
    from convlab.policy.emoTUS.token_map import tokenMap
    from convlab.policy.emoTUS.unify.knowledge_graph import KnowledgeGraph
    from convlab.policy.genTUS.stepGenTUS import \
        UserActionPolicy as GenTUSUserActionPolicy
    from convlab.policy.policy import Policy
    from convlab.util.custom_util import model_downloader
    
    # Module-level debug flag; not referenced in the code shown in this file.
    DEBUG = False
    
    
    class UserActionPolicy(GenTUSUserActionPolicy):
        """Emotion-aware user simulator (emoTUS) built on top of GenTUS.

        In addition to the semantic user action and (optionally) the natural
        language utterance produced by GenTUS, the model first decodes an
        emotion token (e.g. "Neutral"); the last predicted emotion is exposed
        via ``self.emotion``.
        """

        def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
            """Load the model and the emowoz knowledge graph.

            Args:
                model_checkpoint: path to the fine-tuned model directory.
                mode: "semantic" to return actions, "language" to return text.
                only_action: if True, decoding stops after the action part.
                max_turn: maximum number of turns before the dialogue ends.
                **kwargs: forwarded to ``GenTUSUserActionPolicy``.
            """
            super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs)

            self.kg = KnowledgeGraph(
                tokenizer=self.tokenizer,
                dataset="emowoz")
            # emotion.json maps emotion name -> index; invert it into a list
            # so that self.emotion_list[index] == emotion name.
            # (with-block instead of bare open(): avoids leaking the handle)
            with open("convlab/policy/emoTUS/emotion.json") as f:
                data_emotion = json.load(f)
            self.emotion_list = [""] * len(data_emotion)
            for emotion, index in data_emotion.items():
                self.emotion_list[index] = emotion

            self.init_session()

        def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None):
            """Generate the user's response to the system action.

            Args:
                sys_act: system dialogue act (list of [intent, domain, slot, value]).
                mode: "max" (greedy) or "sample" decoding.
                allow_general_intent: ignored; forced to False below.
                emotion: None lets the model pick the emotion; a specific
                    emotion name conditions generation on it; "all" generates
                    one candidate per known emotion and keeps the "Neutral" one.

            Returns:
                The utterance when ``self.mode == "language"``, otherwise the
                semantic action list.
            """
            # TODO emotion
            allow_general_intent = False
            self.model.eval()

            if not self.add_sys_from_reward:
                self.goal.update_user_goal(action=sys_act, char="sys")
                self.sys_acts.append(sys_act)  # for terminate conversation

            # update constraint
            self.time_step += 2

            history = []
            if self.usr_acts:
                if self.max_history == 1:
                    history = self.usr_acts[-1]
                else:
                    history = self.usr_acts[-1*self.max_history:]
            inputs = json.dumps({"system": sys_act,
                                 "goal": self.goal.get_goal_list(),
                                 "history": history,
                                 "turn": str(int(self.time_step/2))})
            with torch.no_grad():
                if emotion == "all":
                    raw_output = self.generate_from_emotion(
                        raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
                    self._print_emotion_candidates(raw_output)
                    raw_output = raw_output["Neutral"]
                elif emotion is not None:
                    raw_output = self.generate_from_emotion(
                        raw_inputs=inputs, emotion=emotion, mode=mode, allow_general_intent=allow_general_intent)
                    self._print_emotion_candidates(raw_output)
                    raw_output = raw_output[emotion]
                else:
                    raw_output = self._generate_action(
                        raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
            output = self._parse_output(raw_output)
            self.emotion = output["emotion"]
            print(self.emotion)
            self.semantic_action = self._remove_illegal_action(output["action"])
            if not self.only_action:
                self.utterance = output["text"]

            if self.is_finish():
                self.emotion, self.semantic_action, self.utterance = self._good_bye()

            self.goal.update_user_goal(action=self.semantic_action, char="usr")
            self.vector.update_mentioned_domain(self.semantic_action)
            self.usr_acts.append(self.semantic_action)

            if self.mode == "language":
                return self.utterance
            else:
                return self.semantic_action

        def _print_emotion_candidates(self, raw_outputs):
            """Debug-print the parsed action/text for each generated emotion."""
            for emo in raw_outputs:
                output = self._parse_output(raw_outputs[emo])
                print("emo:", emo)
                print("act:", output["action"])
                print("utt:", output["text"])

        def _parse_output(self, in_str):
            """Parse a raw model string into {"emotion", "action", "text"}.

            Strips BOS/EOS markers and repairs the tokenizer's o"clock
            artefact. On malformed JSON, logs the string and falls back to a
            neutral, empty action.
            """
            in_str = str(in_str)
            in_str = in_str.replace('<s>', '').replace(
                '<\\s>', '').replace('o"clock', "o'clock")
            action = {"emotion": "Neutral", "action": [], "text": ""}
            try:
                action = json.loads(in_str)
            except json.JSONDecodeError:
                # Narrowed from a bare except: only a parse failure should be
                # swallowed here, not e.g. KeyboardInterrupt.
                print("invalid action:", in_str)
                print("-"*20)
            return action

        def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="max"):
            """Decode emotion, semantic action and (unless only_action) text.

            Returns the decoded action prefix when ``self.only_action`` is
            True, otherwise the full raw output string including the text.
            """
            self.kg.parse_input(raw_inputs)
            model_input = self.vector.encode(raw_inputs, self.max_in_len)
            # start token
            self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
            pos = self._update_seq([0], 0)
            pos = self._update_seq(self.token_map.get_id('start_json'), pos)
            # The emotion token is generated first, before the action part.
            emotion = self._get_emotion(
                model_input, self.seq[:1, :pos], mode, emotion_mode)
            pos = self._update_seq(emotion["token_id"], pos)
            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
            pos = self._update_seq(self.token_map.get_id('start_act'), pos)

            # get semantic actions
            for act_len in range(self.max_action_len):
                pos = self._get_semantic_action(
                    model_input, pos, mode, allow_general_intent)

                terminate, token_name = self._stop_semantic(
                    model_input, pos, act_len)
                pos = self._update_seq(self.token_map.get_id(token_name), pos)

                if terminate:
                    break

            if self.only_action:
                return self.vector.decode(self.seq[0, :pos])

            pos = self._update_seq(self.token_map.get_id("start_text"), pos)
            text = self._get_text(model_input, pos)

            return text

        def generate_from_emotion(self, raw_inputs, emotion=None, mode="max", allow_general_intent=True):
            """Generate one response per emotion, forcing the emotion token.

            Args:
                raw_inputs: JSON-encoded input (system act, goal, history, turn).
                emotion: restrict generation to this single emotion; when None
                    all emotions in ``self.emotion_list`` are generated.

            Returns:
                Dict mapping emotion name -> raw output string.
                NOTE(review): when ``self.only_action`` is True this returns
                the decoded prefix of the FIRST emotion only (original
                behaviour, preserved here) — confirm whether that is intended.
            """
            self.kg.parse_input(raw_inputs)
            model_input = self.vector.encode(raw_inputs, self.max_in_len)
            responses = {}
            if emotion:
                print("if emotion")
                emotion_list = [emotion]
            else:
                emotion_list = self.emotion_list
                print(emotion_list)
            for emotion in emotion_list:
                # start token
                print("emotion", emotion)
                self.seq = torch.zeros(1, self.max_out_len,
                                       device=self.device).long()
                pos = self._update_seq([0], 0)
                pos = self._update_seq(self.token_map.get_id('start_json'), pos)

                # Force the requested emotion token instead of sampling it.
                pos = self._update_seq(self.kg._get_token_id(emotion), pos)
                pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
                pos = self._update_seq(self.token_map.get_id('start_act'), pos)

                # get semantic actions
                for act_len in range(self.max_action_len):
                    pos = self._get_semantic_action(
                        model_input, pos, mode, allow_general_intent)

                    terminate, token_name = self._stop_semantic(
                        model_input, pos, act_len)
                    pos = self._update_seq(self.token_map.get_id(token_name), pos)

                    if terminate:
                        break

                if self.only_action:
                    return self.vector.decode(self.seq[0, :pos])

                pos = self._update_seq(self.token_map.get_id("start_text"), pos)
                text = self._get_text(model_input, pos)
                responses[emotion] = text

            return responses

        def generate_text_from_give_semantic(self, raw_inputs, semantic_action, emotion="Neutral"):
            """Generate only the utterance for a given action and emotion.

            The emotion token and every [intent, domain, slot, value] tuple of
            ``semantic_action`` are force-fed into the decoder; only the text
            part is generated by the model.
            """
            self.kg.parse_input(raw_inputs)
            model_input = self.vector.encode(raw_inputs, self.max_in_len)
            self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
            pos = self._update_seq([0], 0)
            pos = self._update_seq(self.token_map.get_id('start_json'), pos)
            pos = self._update_seq(self.kg._get_token_id(emotion), pos)
            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
            pos = self._update_seq(self.token_map.get_id('start_act'), pos)

            if len(semantic_action) == 0:
                pos = self._update_seq(self.token_map.get_id("end_act"), pos)

            for act_id, (intent, domain, slot, value) in enumerate(semantic_action):
                pos = self._update_seq(self.kg._get_token_id(intent), pos)
                pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
                pos = self._update_seq(self.kg._get_token_id(domain), pos)
                pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
                pos = self._update_seq(self.kg._get_token_id(slot), pos)
                pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
                pos = self._update_seq(self.kg._get_token_id(value), pos)

                # Separate actions with sep_act; close the list after the last.
                if act_id == len(semantic_action) - 1:
                    token_name = "end_act"
                else:
                    token_name = "sep_act"
                pos = self._update_seq(self.token_map.get_id(token_name), pos)
            pos = self._update_seq(self.token_map.get_id("start_text"), pos)

            raw_output = self._get_text(model_input, pos)
            return self._parse_output(raw_output)["text"]

        def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"):
            """Pick the next emotion token from the model's next-token logits."""
            next_token_logits = self.model.get_next_token_logits(
                model_input, generated_so_far)
            return self.kg.get_emotion(next_token_logits, mode, emotion_mode)

        def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
            """Pick the next intent token from the model's next-token logits."""
            next_token_logits = self.model.get_next_token_logits(
                model_input, generated_so_far)

            return self.kg.get_intent(next_token_logits, mode, allow_general_intent)

        def init_session(self, goal=None):
            """Reset dialogue state and sample or read a goal for a new session.

            Args:
                goal: optional pre-built goal; when None a new goal is sampled
                    (with the police domain excluded at inference time).
            """
            self.token_map = tokenMap(tokenizer=self.tokenizer)
            self.token_map.default(only_action=self.only_action)
            self.time_step = 0
            remove_domain = "police"  # remove police domain in inference

            if not goal:
                self._new_goal(remove_domain=remove_domain)
            else:
                self._read_goal(goal)

            self.vector.init_session(goal=self.goal)

            self.terminated = False
            self.add_sys_from_reward = False
            self.sys_acts = []          # system actions seen so far
            self.usr_acts = []          # user actions produced so far
            self.semantic_action = []   # last user action
            self.utterance = ""         # last user utterance
            self.emotion = "Neutral"    # last predicted emotion

        def _good_bye(self):
            """Return the closing (emotion, action, utterance) triple.

            NOTE(review): the casing differs between branches ('none' vs
            'None') — preserved from the original; confirm which the
            downstream evaluator expects.
            """
            if self.is_success():
                return "Satisfied", [['thank', 'general', 'none', 'none']], "thank you. bye"
            else:
                return "Dissatisfied", [["bye", "general", "None", "None"]], "bye"
    
    
    class UserPolicy(Policy):
        """Convlab ``Policy`` wrapper around the emoTUS ``UserActionPolicy``.

        Downloads pretrained weights on first use, then loads them into the
        underlying user simulator and delegates the Policy interface to it.
        """

        def __init__(self,
                     model_checkpoint,
                     mode="semantic",
                     only_action=True,
                     sample=False,
                     action_penalty=False,
                     **kwargs):
            """Build (and if necessary download) the emoTUS user simulator.

            Args:
                model_checkpoint: directory containing pytorch_model.bin.
                mode: "semantic" or "language" (see UserActionPolicy).
                only_action: stop decoding after the semantic action.
                sample: if True, predict() decodes with sampling, else greedy.
                action_penalty: forwarded to UserActionPolicy.
            """
            if not os.path.exists(os.path.dirname(model_checkpoint)):
                # makedirs (not mkdir): the checkpoint path is nested, so
                # intermediate directories may be missing as well.
                os.makedirs(os.path.dirname(model_checkpoint))
                model_downloader(os.path.dirname(model_checkpoint),
                                 "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")

            self.policy = UserActionPolicy(
                model_checkpoint,
                mode=mode,
                only_action=only_action,
                action_penalty=action_penalty,
                **kwargs)
            self.policy.load(os.path.join(
                model_checkpoint, "pytorch_model.bin"))
            self.sample = sample

        def predict(self, sys_act, mode="max"):
            """Return the simulator's response to ``sys_act``.

            The decoding mode is derived from ``self.sample``; the ``mode``
            argument is kept for interface compatibility but overridden.
            """
            mode = "sample" if self.sample else "max"
            return self.policy.predict(sys_act, mode)

        def init_session(self, goal=None):
            """Start a new dialogue session (optionally with a fixed goal)."""
            self.policy.init_session(goal)

        def is_terminated(self):
            """True once the underlying simulator ends the dialogue."""
            return self.policy.is_terminated()

        def get_reward(self, sys_response=None):
            """Delegate reward computation to the underlying simulator."""
            return self.policy.get_reward(sys_response)

        def get_goal(self):
            """Return the simulator's current goal, or None if unsupported."""
            if hasattr(self.policy, 'get_goal'):
                return self.policy.get_goal()
            return None
    
    
    if __name__ == "__main__":
        import os
    
        from convlab.dialog_agent import PipelineAgent
        # from convlab.nlu.jointBERT.multiwoz import BERTNLU
        from convlab.util.custom_util import set_seed
    
        set_seed(20220220)
        # Test semantic level behaviour
        model_checkpoint = 'convlab/policy/emoTUS/unify/experiments/emowoz_0_1/22-12-05-11-23'
        usr_policy = UserPolicy(
            model_checkpoint,
            mode="language",
            only_action=False)
        # usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
        usr_nlu = None  # BERTNLU()
        usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
        print(usr.policy.get_goal())
    
        print(usr.response([]))
        # print(usr.policy.policy.goal.status)
        print(usr.response([["inform", "restaurant", "area", "centre"],
                            ["request", "restaurant", "food", "?"]]))
        # print(usr.policy.policy.goal.status)
        print(usr.response([["request", "restaurant", "price range", "?"]]))
        # print(usr.policy.policy.goal.status)