Skip to content
Snippets Groups Projects
Select Git revision
  • 83bc73ee66fa5de79c25888326e3586b780acce5
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

stepGenTUS.py

Blame
  • Hsien-Chin Lin's avatar
    Hsien-Chin Lin authored and Carel van Niekerk committed
    f5ce0b8d
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    stepGenTUS.py 22.83 KiB
    import json
    import os
    
    import torch
    from transformers import BartTokenizer
    
    from convlab.policy.genTUS.ppo.vector import stepGenTUSVector
    from convlab.policy.genTUS.stepGenTUSmodel import stepGenTUSmodel
    from convlab.policy.genTUS.token_map import tokenMap
    from convlab.policy.genTUS.unify.Goal import Goal
    from convlab.policy.genTUS.unify.knowledge_graph import KnowledgeGraph
    from convlab.policy.policy import Policy
    from convlab.task.multiwoz.goal_generator import GoalGenerator
    from convlab.util.custom_util import model_downloader
    
    
    DEBUG = False
    
    
    class UserActionPolicy(Policy):
        def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
            self.mode = mode
            # if mode == "semantic" and only_action:
            #     # only generate semantic action in prediction
            print("model_checkpoint", model_checkpoint)
            self.only_action = only_action
            if self.only_action:
                print("change mode to semantic because only_action=True")
                self.mode = "semantic"
            self.max_in_len = 500
            self.max_out_len = 100 if only_action else 200
            max_act_len = kwargs.get("max_act_len", 2)
            print("max_act_len", max_act_len)
            self.max_action_len = max_act_len
            if "max_act_len" in kwargs:
                self.max_out_len = 30 * self.max_action_len
                print("max_act_len", self.max_out_len)
            self.max_turn = max_turn
            if mode not in ["semantic", "language"]:
                print("Unknown user mode")
    
            self.reward = {"success":  self.max_turn*2,
                           "fail": self.max_turn*-1}
            self.tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            train_whole_model = kwargs.get("whole_model", True)
            self.model = stepGenTUSmodel(
                model_checkpoint, train_whole_model=train_whole_model)
            self.model.eval()
            self.model.to(self.device)
            self.model.share_memory()
    
            self.turn_level_reward = kwargs.get("turn_level_reward", True)
            self.cooperative = kwargs.get("cooperative", True)
    
            dataset = kwargs.get("dataset", "")
            self.kg = KnowledgeGraph(
                tokenizer=self.tokenizer,
                dataset=dataset)
    
            self.goal_gen = GoalGenerator()
    
            self.vector = stepGenTUSVector(
                model_checkpoint, self.max_in_len, self.max_out_len)
            self.norm_reward = False
    
            self.action_penalty = kwargs.get("action_penalty", False)
            self.usr_act_penalize = kwargs.get("usr_act_penalize", 0)
            self.goal_list_type = kwargs.get("goal_list_type", "normal")
            self.update_mode = kwargs.get("update_mode", "normal")
            self.max_history = kwargs.get("max_history", 3)
            self.init_session()
    
        def _update_seq(self, sub_seq: list, pos: int):
            for x in sub_seq:
                self.seq[0, pos] = x
                pos += 1
    
            return pos
    
        def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True):
            # TODO no duplicate
            self.kg.parse_input(raw_inputs)
            model_input = self.vector.encode(raw_inputs, self.max_in_len)
            # start token
            self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
            pos = self._update_seq([0], 0)
            pos = self._update_seq(self.token_map.get_id('start_json'), pos)
            pos = self._update_seq(self.token_map.get_id('start_act'), pos)
    
            # get semantic actions
            for act_len in range(self.max_action_len):
                pos = self._get_semantic_action(
                    model_input, pos, mode, allow_general_intent)
    
                terminate, token_name = self._stop_semantic(
                    model_input, pos, act_len)
                pos = self._update_seq(self.token_map.get_id(token_name), pos)
    
                if terminate:
                    break
    
            if self.only_action:
                # return semantic action. Don't need to generate text
                return self.vector.decode(self.seq[0, :pos])
    
            # TODO remove illegal action here?
    
            # get text output
            pos = self._update_seq(self.token_map.get_id("start_text"), pos)
    
            text = self._get_text(model_input, pos)
    
            return text
    
        def generate_text_from_give_semantic(self, raw_inputs, semantic_action):
            self.kg.parse_input(raw_inputs)
            model_input = self.vector.encode(raw_inputs, self.max_in_len)
            self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
            pos = self._update_seq([0], 0)
            pos = self._update_seq(self.token_map.get_id('start_json'), pos)
            pos = self._update_seq(self.token_map.get_id('start_act'), pos)
    
            if len(semantic_action) == 0:
                pos = self._update_seq(self.token_map.get_id("end_act"), pos)
    
            for act_id, (intent, domain, slot, value) in enumerate(semantic_action):
                pos = self._update_seq(self.kg._get_token_id(intent), pos)
                pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
                pos = self._update_seq(self.kg._get_token_id(domain), pos)
                pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
                pos = self._update_seq(self.kg._get_token_id(slot), pos)
                pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
                pos = self._update_seq(self.kg._get_token_id(value), pos)
    
                if act_id == len(semantic_action) - 1:
                    token_name = "end_act"
                else:
                    token_name = "sep_act"
                pos = self._update_seq(self.token_map.get_id(token_name), pos)
            pos = self._update_seq(self.token_map.get_id("start_text"), pos)
    
            raw_output = self._get_text(model_input, pos)
            return self._parse_output(raw_output)["text"]
    
        def _get_text(self, model_input, pos):
            s_pos = pos
            for i in range(s_pos, self.max_out_len):
                next_token_logits = self.model.get_next_token_logits(
                    model_input, self.seq[:1, :pos])
                next_token = torch.argmax(next_token_logits, dim=-1)
    
                if self._stop_text(next_token):
                    # text = self.vector.decode(self.seq[0, s_pos:pos])
                    # text = self._norm_str(text)
                    # return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
                    break
    
                pos = self._update_seq([next_token], pos)
            text = self.vector.decode(self.seq[0, s_pos:pos])
            text = self._norm_str(text)
            return self.vector.decode(self.seq[0, :s_pos]) + text + '"}'
            # TODO return None
    
        def _stop_text(self, next_token):
            if next_token == self.token_map.get_id("end_json")[0]:
                return True
            elif next_token == self.token_map.get_id("end_json_2")[0]:
                return True
    
            return False
    
        @staticmethod
        def _norm_str(text: str):
            text = text.strip('"')
            text = text.replace('"', "'")
            text = text.replace('\\', "")
            return text
    
        def _stop_semantic(self, model_input, pos, act_length=0):
    
            outputs = self.model.get_next_token_logits(
                model_input, self.seq[:1, :pos])
            tokens = {}
            for token_name in ['sep_act', 'end_act']:
                tokens[token_name] = {
                    "token_id": self.token_map.get_id(token_name)}
                hash_id = tokens[token_name]["token_id"][0]
                tokens[token_name]["score"] = outputs[:, hash_id].item()
    
            if tokens['end_act']["score"] > tokens['sep_act']["score"]:
                terminate = True
            else:
                terminate = False
    
            if act_length >= self.max_action_len - 1:
                terminate = True
    
            token_name = "end_act" if terminate else "sep_act"
    
            return terminate, token_name
    
        def _get_semantic_action(self, model_input, pos, mode="max", allow_general_intent=True):
    
            intent = self._get_intent(
                model_input, self.seq[:1, :pos], mode, allow_general_intent)
            pos = self._update_seq(intent["token_id"], pos)
            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
    
            # get domain
            domain = self._get_domain(
                model_input, self.seq[:1, :pos], intent["token_name"], mode)
            pos = self._update_seq(domain["token_id"], pos)
            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
    
            # get slot
            slot = self._get_slot(
                model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], mode)
            pos = self._update_seq(slot["token_id"], pos)
            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
    
            # get value
    
            value = self._get_value(
                model_input, self.seq[:1, :pos], intent["token_name"], domain["token_name"], slot["token_name"], mode)
            pos = self._update_seq(value["token_id"], pos)
    
            return pos
    
        def _get_intent(self, model_input, generated_so_far, mode="max", allow_general_intent=True):
            next_token_logits = self.model.get_next_token_logits(
                model_input, generated_so_far)
    
            return self.kg.get_intent(next_token_logits, mode, allow_general_intent)
    
        def _get_domain(self, model_input, generated_so_far, intent, mode="max"):
            next_token_logits = self.model.get_next_token_logits(
                model_input, generated_so_far)
    
            return self.kg.get_domain(next_token_logits, intent, mode)
    
        def _get_slot(self, model_input, generated_so_far, intent, domain, mode="max"):
            next_token_logits = self.model.get_next_token_logits(
                model_input, generated_so_far)
            is_mentioned = self.vector.is_mentioned(domain)
            return self.kg.get_slot(next_token_logits, intent, domain, mode, is_mentioned)
    
        def _get_value(self, model_input, generated_so_far, intent, domain, slot, mode="max"):
            next_token_logits = self.model.get_next_token_logits(
                model_input, generated_so_far)
    
            return self.kg.get_value(next_token_logits, intent, domain, slot, mode)
    
        def _remove_illegal_action(self, action):
            # Transform illegal action to legal action
            new_action = []
            for act in action:
                if len(act) == 4:
                    if "<?>" in act[-1]:
                        act = [act[0], act[1], act[2], "?"]
                    if act not in new_action:
                        new_action.append(act)
                else:
                    print("illegal action:", action)
            return new_action
    
        def _parse_output(self, in_str):
            in_str = str(in_str)
            in_str = in_str.replace('<s>', '').replace(
                '<\\s>', '').replace('o"clock', "o'clock")
            action = {"action": [], "text": ""}
            try:
                action = json.loads(in_str)
            except:
                print("invalid action:", in_str)
                print("-"*20)
            return action
    
        def predict(self, sys_act, mode="max", allow_general_intent=True):
            # raw_sys_act = sys_act
            # sys_act = sys_act[:5]
            # update goal
            # TODO
            allow_general_intent = False
            self.model.eval()
    
            if not self.add_sys_from_reward:
                self.goal.update_user_goal(action=sys_act, char="sys")
                self.sys_acts.append(sys_act)  # for terminate conversation
    
            # update constraint
            self.time_step += 2
    
            history = []
            if self.usr_acts:
                if self.max_history == 1:
                    history = self.usr_acts[-1]
                else:
                    history = self.usr_acts[-1*self.max_history:]
            inputs = json.dumps({"system": sys_act,
                                 "goal": self.goal.get_goal_list(),
                                 "history": history,
                                 "turn": str(int(self.time_step/2))})
            with torch.no_grad():
                raw_output = self._generate_action(
                    raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
            output = self._parse_output(raw_output)
            self.semantic_action = self._remove_illegal_action(output["action"])
            if not self.only_action:
                self.utterance = output["text"]
    
            # TODO
            if self.is_finish():
                self.semantic_action, self.utterance = self._good_bye()
    
            # if self.is_finish():
            #     print("terminated")
    
            # if self.is_finish():
            #     good_bye = self._good_bye()
            #     self.goal.add_usr_da(good_bye)
            #     return good_bye
    
            self.goal.update_user_goal(action=self.semantic_action, char="usr")
            self.vector.update_mentioned_domain(self.semantic_action)
            self.usr_acts.append(self.semantic_action)
    
            # if self._usr_terminate(usr_action):
            #     print("terminated by user")
            #     self.terminated = True
    
            del inputs
    
            if self.mode == "language":
                # print("in", sys_act)
                # print("out", self.utterance)
                return self.utterance
            else:
                return self.semantic_action
    
        def init_session(self, goal=None):
            self.token_map = tokenMap(tokenizer=self.tokenizer)
            self.token_map.default(only_action=self.only_action)
            self.time_step = 0
            remove_domain = "police"  # remove police domain in inference
    
            if not goal:
                self._new_goal(remove_domain=remove_domain)
            else:
                self._read_goal(goal)
    
            self.vector.init_session(goal=self.goal)
    
            self.terminated = False
            self.add_sys_from_reward = False
            self.sys_acts = []
            self.usr_acts = []
            self.semantic_action = []
            self.utterance = ""
    
        def _read_goal(self, data_goal):
            self.goal = Goal(goal=data_goal)
    
        def _new_goal(self, remove_domain="police", domain_len=None):
            self.goal = Goal(goal_generator=self.goal_gen)
            # keep_generate_goal = True
            # # domain_len = 1
            # while keep_generate_goal:
            #     self.goal = Goal(goal_generator=self.goal_gen,
            #                      goal_list_type=self.goal_list_type,
            #                      update_mode=self.update_mode)
            #     if (domain_len and len(self.goal.domains) != domain_len) or \
            #             (remove_domain and remove_domain in self.goal.domains):
            #         keep_generate_goal = True
            #     else:
            #         keep_generate_goal = False
    
        def load(self, model_path):
            self.model.load_state_dict(torch.load(
                model_path, map_location=self.device))
            # self.model = BartForConditionalGeneration.from_pretrained(
            #     model_checkpoint)
    
        def get_goal(self):
            if self.goal.raw_goal is not None:
                return self.goal.raw_goal
            goal = {}
            for domain in self.goal.domain_goals:
                if domain not in goal:
                    goal[domain] = {}
                for intent in self.goal.domain_goals[domain]:
                    if intent == "inform":
                        slot_type = "info"
                    elif intent == "request":
                        slot_type = "reqt"
                    elif intent == "book":
                        slot_type = "book"
                    else:
                        print("unknown slot type")
                    if slot_type not in goal[domain]:
                        goal[domain][slot_type] = {}
                    for slot, value in self.goal.domain_goals[domain][intent].items():
                        goal[domain][slot_type][slot] = value
            return goal
    
        def get_reward(self, sys_response=None):
            self.add_sys_from_reward = False if sys_response is None else True
    
            if self.add_sys_from_reward:
                self.goal.update_user_goal(action=sys_response, char="sys")
                self.goal.add_sys_da(sys_response)  # for evaluation
                self.sys_acts.append(sys_response)  # for terminate conversation
    
            if self.is_finish():
                if self.is_success():
                    reward = self.reward["success"]
                    self.success = True
                else:
                    reward = self.reward["fail"]
                    self.success = False
    
            else:
                reward = -1
                if self.turn_level_reward:
                    reward += self.turn_reward()
    
                self.success = None
                # if self.action_penalty:
                #     reward += self._system_action_penalty()
    
            if self.norm_reward:
                reward = (reward - 20)/60
            return reward
    
        def _system_action_penalty(self):
            free_action_len = 3
            if len(self.sys_acts) < 1:
                return 0
            # TODO only penalize the slots not in user goal
            # else:
            #     penlaty = 0
            #     for i in range(len(self.sys_acts[-1])):
            #         penlaty += -1*i
            #     return penlaty
            if len(self.sys_acts[-1]) > 3:
                return -1*(len(self.sys_acts[-1])-free_action_len)
            return 0
    
        def turn_reward(self):
            r = 0
            r += self._new_act_reward()
            r += self._reply_reward()
            r += self._usr_act_len()
            return r
    
        def _usr_act_len(self):
            last_act = self.usr_acts[-1]
            penalty = 0
            if len(last_act) > 2:
                penalty = (2-len(last_act))*self.usr_act_penalize
            return penalty
    
        def _new_act_reward(self):
            last_act = self.usr_acts[-1]
            if last_act != self.semantic_action:
                print(f"---> why? last {last_act} usr {self.semantic_action}")
            new_act = []
            for act in last_act:
                if len(self.usr_acts) < 2:
                    break
                if act[1].lower() == "general":
                    new_act.append(0)
                elif act in self.usr_acts[-2]:
                    new_act.append(-1)
                elif act not in self.usr_acts[-2]:
                    new_act.append(1)
    
            return sum(new_act)
    
        def _reply_reward(self):
            if self.cooperative:
                return self._cooperative_reply_reward()
            else:
                return self._non_cooperative_reply_reward()
    
        def _non_cooperative_reply_reward(self):
            r = []
            reqts = []
            infos = []
            reply_len = 0
            max_len = 1
            for act in self.sys_acts[-1]:
                if act[0] == "request":
                    reqts.append([act[1], act[2]])
            for act in self.usr_acts[-1]:
                if act[0] == "inform":
                    infos.append([act[1], act[2]])
            for req in reqts:
                if req in infos:
                    if reply_len < max_len:
                        r.append(1)
                    elif reply_len == max_len:
                        r.append(0)
                    else:
                        r.append(-5)
    
            if r:
                return sum(r)
            return 0
    
        def _cooperative_reply_reward(self):
            r = []
            reqts = []
            infos = []
            for act in self.sys_acts[-1]:
                if act[0] == "request":
                    reqts.append([act[1], act[2]])
            for act in self.usr_acts[-1]:
                if act[0] == "inform":
                    infos.append([act[1], act[2]])
            for req in reqts:
                if req in infos:
                    r.append(1)
                else:
                    r.append(-1)
            if r:
                return sum(r)
            return 0
    
        def _usr_terminate(self):
            for act in self.semantic_action:
                if act[0] in ['thank', 'bye']:
                    return True
            return False
    
        def is_finish(self):
            # stop by model generation?
            if self._finish_conversation_rule():
                self.terminated = True
                return True
            elif self._usr_terminate():
                self.terminated = True
                return True
            self.terminated = False
            return False
    
        def is_success(self):
            task_complete = self.goal.task_complete()
            # goal_status = self.goal.all_mentioned()
            # should mentioned all slots
            if task_complete:  # and goal_status["complete"] > 0.6:
                return True
            return False
    
        def _good_bye(self):
            if self.is_success():
                return [['thank', 'general', 'none', 'none']], "thank you. bye"
                # if self.mode == "semantic":
                #     return [['thank', 'general', 'none', 'none']]
                # else:
                #     return "bye"
            else:
                return [["bye", "general", "None", "None"]], "bye"
                if self.mode == "semantic":
                    return [["bye", "general", "None", "None"]]
                return "bye"
    
        def _finish_conversation_rule(self):
            if self.is_success():
                return True
    
            if self.time_step > self.max_turn:
                return True
    
            if (len(self.sys_acts) > 4) and (self.sys_acts[-1] == self.sys_acts[-2]) and (self.sys_acts[-2] == self.sys_acts[-3]):
                return True
            return False
    
        def is_terminated(self):
            # Is there any action to say?
            self.is_finish()
            return self.terminated
    
    
    class UserPolicy(Policy):
        def __init__(self,
                     model_checkpoint,
                     mode="semantic",
                     only_action=True,
                     sample=False,
                     action_penalty=False,
                     **kwargs):
            # self.config = config
            if not os.path.exists(os.path.dirname(model_checkpoint)):
                os.mkdir(os.path.dirname(model_checkpoint))
                model_downloader(os.path.dirname(model_checkpoint),
                                 "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
    
            self.policy = UserActionPolicy(
                model_checkpoint,
                mode=mode,
                only_action=only_action,
                action_penalty=action_penalty,
                **kwargs)
            self.policy.load(os.path.join(
                model_checkpoint, "pytorch_model.bin"))
            self.sample = sample
    
        def predict(self, sys_act, mode="max"):
            if self.sample:
                mode = "sample"
            else:
                mode = "max"
            response = self.policy.predict(sys_act, mode)
            return response
    
        def init_session(self, goal=None):
            self.policy.init_session(goal)
    
        def is_terminated(self):
            return self.policy.is_terminated()
    
        def get_reward(self, sys_response=None):
            return self.policy.get_reward(sys_response)
    
        def get_goal(self):
            if hasattr(self.policy, 'get_goal'):
                return self.policy.get_goal()
            return None
    
    
    if __name__ == "__main__":
        import os
    
        from convlab.dialog_agent import PipelineAgent
        # from convlab.nlu.jointBERT.multiwoz import BERTNLU
        from convlab.util.custom_util import set_seed
    
        set_seed(20220220)
        # Test semantic level behaviour
        model_checkpoint = 'convlab/policy/genTUS/unify/experiments/multiwoz21-exp'
        usr_policy = UserPolicy(
            model_checkpoint,
            mode="semantic")
        # usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
        usr_nlu = None  # BERTNLU()
        usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
        print(usr.policy.get_goal())
    
        print(usr.response([]))
        print(usr.policy.policy.goal.status)
        print(usr.response([["request", "attraction", "area", "?"]]))
        print(usr.policy.policy.goal.status)
        print(usr.response([["request", "attraction", "area", "?"]]))
        print(usr.policy.policy.goal.status)