diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py
index 39ebe0050c3fa2f9cfd5f2c8eda73bcf563919c5..8590317da67d3d03eb73b24e4a73d17fbf998bf5 100644
--- a/convlab/policy/emoTUS/emoTUS.py
+++ b/convlab/policy/emoTUS/emoTUS.py
@@ -15,11 +15,12 @@ DEBUG = False
 
 
 class UserActionPolicy(GenTUSUserActionPolicy):
-    def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
+    def __init__(self, model_checkpoint, mode="language", only_action=False, max_turn=40, **kwargs):
         self.use_sentiment = kwargs.get("use_sentiment", False)
-        print("use_sentiment", self.use_sentiment)
+        self.add_persona = kwargs.get("add_persona", False)
+        self.emotion_mid = kwargs.get("emotion_mid", False)
+
         super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs)
-        print("sentiment", self.use_sentiment)
         weight = kwargs.get("weight", None)
         self.kg = KnowledgeGraph(
             tokenizer=self.tokenizer,
@@ -52,22 +53,17 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         else:
             history = self.usr_acts[-1*self.max_history:]
 
-        # TODO add user info? impolite? -> check self.use_sentiment
-        if self.use_sentiment:
-            # TODO how to get event and user politeness?
-            input_dict = {"system": sys_act,
-                          "goal": self.goal.get_goal_list(),
-                          "history": history,
-                          "turn": str(int(self.time_step/2))}
+        input_dict = {"system": sys_act,
+                      "goal": self.goal.get_goal_list(),
+                      "history": history,
+                      "turn": str(int(self.time_step/2))}
+
+        if self.add_persona:
             for user, info in self.user_info.items():
                 input_dict[user] = info
-            inputs = json.dumps(input_dict)
-        else:
-            inputs = json.dumps({"system": sys_act,
-                                 "goal": self.goal.get_goal_list(),
-                                 "history": history,
-                                 "turn": str(int(self.time_step/2))})
+
+        inputs = json.dumps(input_dict)
 
         with torch.no_grad():
             if emotion == "all":
                 raw_output = self.generate_from_emotion(
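For reference, the refactored block above now builds a single `input_dict` and serializes it once; persona fields from `self.user_info` are merged in only when `add_persona` is set. A minimal sketch of the resulting model input, with purely illustrative values (the real entries come from the dialogue manager and the GenTUS goal representation):

```python
import json

# Illustrative values only; the exact act and goal encodings are
# produced upstream by GenTUS, not hard-coded like this.
input_dict = {
    "system": [["inform", "hotel", "area", "north"]],  # last system acts
    "goal": [["hotel", "info", "area", "north"]],      # flattened goal list
    "history": [],       # up to max_history previous user acts
    "turn": "0",         # user turn index, i.e. time_step // 2
}
# with add_persona=True, entries from user_info are added as extra
# top-level keys (key names depend on the corpus annotation)
inputs = json.dumps(input_dict)
```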
@@ -91,16 +87,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
             raw_output = self._generate_action(
                 raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
         output = self._parse_output(raw_output)
         self.semantic_action = self._remove_illegal_action(output["action"])
-        if not self.only_action:
-            self.utterance = output["text"]
-
+        self.utterance = output["text"]
         self.emotion = output["emotion"]
         if self.use_sentiment:
             self.sentiment = output["sentiment"]
-        # print("---> sentiment", self.sentiment)
-        # print("---> emotion", self.emotion)
-        # print("---> self.utterance", self.utterance)
 
         if self.is_finish():
             self.emotion, self.semantic_action, self.utterance = self._good_bye()
@@ -113,12 +105,7 @@ class UserActionPolicy(GenTUSUserActionPolicy):
 
         del inputs
 
-        if self.mode == "language":
-            # print("in", sys_act)
-            # print("out", self.utterance)
-            return self.utterance
-        else:
-            return self.semantic_action
+        return self.utterance
 
     def _parse_output(self, in_str):
         in_str = str(in_str)
@@ -135,25 +122,24 @@ class UserActionPolicy(GenTUSUserActionPolicy):
             print("-"*20)
         return action
 
-    def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="normal"):
-        self.kg.parse_input(raw_inputs)
-        model_input = self.vector.encode(raw_inputs, self.max_in_len)
-        # start token
-        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
-        pos = self._update_seq([0], 0)
-        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
-        if self.use_sentiment:
-            sentiment = self._get_sentiment(
-                model_input, self.seq[:1, :pos], mode)
-            pos = self._update_seq(sentiment["token_id"], pos)
-        else:
-            emotion = self._get_emotion(
-                model_input, self.seq[:1, :pos], mode, emotion_mode)
-            pos = self._update_seq(emotion["token_id"], pos)
-        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
-        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
-
-        # get semantic actions
+    def _update_sentiment(self, pos, model_input, mode):
+        pos = self._update_seq(
+            self.token_map.get_id('start_sentiment'), pos)
+        sentiment = self._get_sentiment(
+            model_input, self.seq[:1, :pos], mode)
+        pos = self._update_seq(sentiment["token_id"], pos)
+        return sentiment, pos
+
+    def _update_emotion(self, pos, model_input, mode, emotion_mode, sentiment=None):
+        pos = self._update_seq(
+            self.token_map.get_id('start_emotion'), pos)
+        emotion = self._get_emotion(
+            model_input, self.seq[:1, :pos], mode, emotion_mode, sentiment)
+        pos = self._update_seq(emotion["token_id"], pos)
+        return pos
+
+    def _update_semantic_act(self, pos, model_input, mode, allow_general_intent):
+        mode = "max"  # semantic actions are always decoded greedily
         for act_len in range(self.max_action_len):
            pos = self._get_semantic_action(
                 model_input, pos, mode, allow_general_intent)
@@ -164,16 +150,82 @@ class UserActionPolicy(GenTUSUserActionPolicy):
 
             if terminate:
                 break
+        return pos
 
-        if self.only_action:
-            return self.vector.decode(self.seq[0, :pos])
+    def _sent_act_emo(self, pos, model_input, mode, emotion_mode, allow_general_intent):
+        # sentiment
+        sentiment, pos = self._update_sentiment(pos, model_input, mode)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+        # act
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+        pos = self._update_semantic_act(
+            pos, model_input, mode, allow_general_intent)
+        # emotion
+        pos = self._update_emotion(
+            pos, model_input, mode, emotion_mode, sentiment["token_name"])
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
 
-        if self.use_sentiment:
-            pos = self._update_seq(self.token_map.get_id('start_emotion'), pos)
-            emotion = self._get_emotion(
-                model_input, self.seq[:1, :pos], mode, emotion_mode, sentiment["token_name"])
-            pos = self._update_seq(emotion["token_id"], pos)
-            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+        return pos
+
+    def _sent_emo_act(self, pos, model_input, mode, emotion_mode, allow_general_intent):
+        # sentiment
+        sentiment, pos = self._update_sentiment(pos, model_input, mode)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+        # emotion
+        pos = self._update_emotion(
+            pos, model_input, mode, emotion_mode, sentiment["token_name"])
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+        # act
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+        pos = self._update_semantic_act(
+            pos, model_input, mode, allow_general_intent)
+
+        return pos
+
+    def _emo_act(self, pos, model_input, mode, emotion_mode, allow_general_intent):
+        # emotion
+        pos = self._update_emotion(
+            pos, model_input, mode, emotion_mode)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+        # act
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+        pos = self._update_semantic_act(
+            pos, model_input, mode, allow_general_intent)
+
+        return pos
+
+    def _act_emo(self, pos, model_input, mode, emotion_mode, allow_general_intent):
+        # act
+        pos = self._update_seq(self.token_map.get_id('start_act'), pos)
+        pos = self._update_semantic_act(
+            pos, model_input, mode, allow_general_intent)
+        # emotion
+        pos = self._update_emotion(
+            pos, model_input, mode, emotion_mode)
+        pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
+        return pos
+
+    def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="normal"):
+        self.kg.parse_input(raw_inputs)
+        model_input = self.vector.encode(raw_inputs, self.max_in_len)
+        # start token
+        self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
+        pos = self._update_seq([0], 0)
+        pos = self._update_seq(self.token_map.get_id('start_json'), pos)
+
+        if self.use_sentiment and self.emotion_mid:
+            pos = self._sent_act_emo(
+                pos, model_input, mode, emotion_mode, allow_general_intent)
+        elif self.use_sentiment and not self.emotion_mid:
+            pos = self._sent_emo_act(
+                pos, model_input, mode, emotion_mode, allow_general_intent)
+        elif not self.use_sentiment and self.emotion_mid:
+            pos = self._act_emo(
+                pos, model_input, mode, emotion_mode, allow_general_intent)
+        else:
+            pos = self._emo_act(
+                pos, model_input, mode, emotion_mode, allow_general_intent)
 
         pos = self._update_seq(self.token_map.get_id("start_text"), pos)
         text = self._get_text(model_input, pos)
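The four helpers above differ only in the order in which the sentiment, emotion, and action spans are decoded; `emotion_mid` places the emotion span between the actions and the text rather than before the actions. A compact summary of the dispatch in the new `_generate_action`, as a reading aid rather than code from the patch:

```python
# (use_sentiment, emotion_mid) -> span order inside the generated JSON;
# every variant is followed by the "text" span appended after the dispatch.
ORDERINGS = {
    (True,  True):  ("sentiment", "action", "emotion"),  # _sent_act_emo
    (True,  False): ("sentiment", "emotion", "action"),  # _sent_emo_act
    (False, True):  ("action", "emotion"),               # _act_emo
    (False, False): ("emotion", "action"),               # _emo_act
}
```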
@@ -332,8 +384,6 @@ class UserActionPolicy(GenTUSUserActionPolicy):
 class UserPolicy(Policy):
     def __init__(self,
                  model_checkpoint,
-                 mode="semantic",
-                 only_action=True,
                  sample=False,
                  action_penalty=False,
                  **kwargs):
@@ -342,7 +392,8 @@ class UserPolicy(Policy):
             os.mkdir(os.path.dirname(model_checkpoint))
             model_downloader(os.path.dirname(model_checkpoint),
                              "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
-
+        only_action = False
+        mode = "language"
         self.policy = UserActionPolicy(
             model_checkpoint,
             mode=mode,
@@ -385,15 +436,15 @@ if __name__ == "__main__":
     # from convlab.nlu.jointBERT.multiwoz import BERTNLU
     from convlab.util.custom_util import set_seed
-    set_seed(20220220)
+    use_sentiment, emotion_mid = True, True
+    set_seed(0)
 
     # Test language-level behaviour
     model_checkpoint = 'convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17'
     usr_policy = UserPolicy(
         model_checkpoint,
-        mode="language",
-        only_action=False,
-        use_sentiment=True,
-        sample=True)
+        sample=True,
+        use_sentiment=use_sentiment,
+        emotion_mid=emotion_mid)
     # usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
     usr_nlu = None  # BERTNLU()
     usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
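Since `mode` and `only_action` are now fixed inside `UserPolicy`, callers pick the behaviour variants purely through keyword arguments, as the revised `__main__` block does. A minimal usage sketch (the checkpoint path is a placeholder):

```python
from convlab.dialog_agent import PipelineAgent
from convlab.policy.emoTUS.emoTUS import UserPolicy

usr_policy = UserPolicy(
    "path/to/emoTUS/checkpoint",  # placeholder checkpoint directory
    sample=True,                  # sample instead of greedy decoding
    use_sentiment=True,           # decode a sentiment span first
    emotion_mid=True,             # decode emotion after the semantic actions
)
# predict() now always returns a natural-language utterance
usr = PipelineAgent(None, None, usr_policy, None, name="user")
```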
diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py
index 8aa4606ab63fe8ed80c432ca7e8a1294474de884..86af813d755ad549d44b6106cfe50018a27127c3 100644
--- a/convlab/policy/emoTUS/evaluate.py
+++ b/convlab/policy/emoTUS/evaluate.py
@@ -28,35 +28,34 @@ def arg_parser():
                         default="")
     parser.add_argument("--generated-file", type=str, help="the generated results",
                         default="")
-    parser.add_argument("--only-action", action="store_true")
     parser.add_argument("--dataset", default="multiwoz")
-    parser.add_argument("--do-semantic", action="store_true",
-                        help="do semantic evaluation")
-    parser.add_argument("--do-nlg", action="store_true",
-                        help="do nlg generation")
     parser.add_argument("--do-golden-nlg", action="store_true",
                         help="do golden nlg generation")
-    parser.add_argument("--no-neutral", action="store_true",
-                        help="skip neutral emotion")
     parser.add_argument("--use-sentiment", action="store_true")
+    parser.add_argument("--emotion-mid", action="store_true")
     parser.add_argument("--weight", type=float, default=None)
     return parser.parse_args()
 
 
 class Evaluator:
-    def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False, use_sentiment=False, weight=None):
+    def __init__(self, model_checkpoint, dataset, model_weight=None, **kwargs):
         self.dataset = dataset
         self.model_checkpoint = model_checkpoint
         self.model_weight = model_weight
         self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
-        self.use_sentiment = use_sentiment
+        self.use_sentiment = kwargs.get("use_sentiment", False)
+        self.add_persona = kwargs.get("add_persona", False)
+        self.emotion_mid = kwargs.get("emotion_mid", False)
+        weight = kwargs.get("weight", None)
         self.usr = UserActionPolicy(
             model_checkpoint,
-            only_action=only_action,
             dataset=self.dataset,
-            use_sentiment=use_sentiment,
+            use_sentiment=self.use_sentiment,
+            add_persona=self.add_persona,
+            emotion_mid=self.emotion_mid,
             weight=weight)
+
         self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
 
         self.r = {"input": [],
@@ -66,7 +65,8 @@ class Evaluator:
                   "gen_acts": [],
                   "gen_utts": [],
                   "gen_emotion": []}
-        if use_sentiment:
+
+        if self.use_sentiment:
             self.r["golden_sentiment"] = []
             self.r["gen_sentiment"] = []
 
@@ -81,17 +81,13 @@ class Evaluator:
             for x in self.r:
                 self.r[x].append(temp[x])
 
-    def generate_results(self, f_eval, golden=False, no_neutral=False):
+    def generate_results(self, f_eval, golden=False):
         emotion_mode = "normal"
-        if no_neutral:
-            emotion_mode = "no_neutral"
         in_file = json.load(open(f_eval))
 
         for dialog in tqdm(in_file['dialog']):
             inputs = dialog["in"]
             labels = self.usr._parse_output(dialog["out"])
-            if no_neutral and labels["emotion"].lower() == "neutral":
-                continue
 
             if golden:
                 usr_act = labels["action"]
@@ -138,10 +134,10 @@ class Evaluator:
             result.append(temp)
         return result
 
-    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False):
+    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
         if input_file:
             print("Force generation")
-            self.generate_results(input_file, golden, no_neutral)
+            self.generate_results(input_file, golden)
 
         elif generated_file:
             self.read_generated_result(generated_file)
@@ -240,7 +236,7 @@ class Evaluator:
         for metric in scores:
             result[metric] = sum(scores[metric])/len(scores[metric])
             print(f"{metric}: {result[metric]}")
-        # TODO no neutral
+
         emo_score = emotion_score(
             golden_emotions,
             gen_emotions,
@@ -338,27 +334,23 @@ def main():
     eval = Evaluator(args.model_checkpoint,
                      args.dataset,
                      args.model_weight,
-                     args.only_action,
-                     args.use_sentiment,
+                     use_sentiment=args.use_sentiment,
+                     emotion_mid=args.emotion_mid,
                      weight=args.weight)
     print("model checkpoint", args.model_checkpoint)
     print("generated_file", args.generated_file)
     print("input_file", args.input_file)
     with torch.no_grad():
-        if args.do_semantic:
-            eval.evaluation(args.input_file)
-        if args.do_nlg:
-            if args.generated_file:
-                generated_file = args.generated_file
-            else:
-                nlg_result = eval.nlg_evaluation(input_file=args.input_file,
-                                                 generated_file=args.generated_file,
-                                                 golden=args.do_golden_nlg,
-                                                 no_neutral=args.no_neutral)
-
-                generated_file = nlg_result
-            eval.evaluation(args.input_file,
-                            generated_file)
+        if args.generated_file:
+            generated_file = args.generated_file
+        else:
+            nlg_result = eval.nlg_evaluation(input_file=args.input_file,
+                                             generated_file=args.generated_file,
+                                             golden=args.do_golden_nlg)
+
+            generated_file = nlg_result
+        eval.evaluation(args.input_file,
+                        generated_file)
 
 
 if __name__ == '__main__':
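The simplified `main()` always runs the NLG evaluation path; `--do-semantic`, `--do-nlg`, and `--no-neutral` are gone. Driving the evaluator from Python now looks roughly like this (file paths are placeholders):

```python
import torch
from convlab.policy.emoTUS.evaluate import Evaluator

evaluator = Evaluator("path/to/checkpoint", dataset="multiwoz",
                      use_sentiment=True, emotion_mid=True)
with torch.no_grad():
    # generate utterances for the test file, then score them
    generated_file = evaluator.nlg_evaluation(input_file="path/to/test.json")
    evaluator.evaluation("path/to/test.json", generated_file)
```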
diff --git a/convlab/policy/emoTUS/token_map.py b/convlab/policy/emoTUS/token_map.py
index 63576d5748f356bcecc768f94e0d2a18bb71833c..407e3102cdeda2690461fa28805aef687868febd 100644
--- a/convlab/policy/emoTUS/token_map.py
+++ b/convlab/policy/emoTUS/token_map.py
@@ -2,28 +2,26 @@
 import json
 
 class tokenMap:
-    def __init__(self, tokenizer, use_sentiment=False):
+    def __init__(self, tokenizer, **kwargs):
         self.tokenizer = tokenizer
         self.token_name = {}
         self.hash_map = {}
         self.debug = False
-        self.use_sentiment = use_sentiment
         self.default()
 
     def default(self, only_action=False):
         self.format_tokens = {
-            'start_json': '{"emotion": "',  # 49643, 10845, 7862, 646
-            'start_act': 'action": [["',  # 49329
-            'sep_token': '", "',  # 1297('",'), 22
-            'sep_act': '"], ["',  # 49177
-            'end_act': '"]], "',  # 42248, 7479, 22
-            'start_text': 'text": "',  # 29015, 7862, 22
-            'end_json': '}',  # 24303
-            'end_json_2': '"}'  # 48805
+            'start_json': '{"',
+            'start_sentiment': 'sentiment": "',
+            'start_emotion': 'emotion": "',
+            'start_act': 'action": [["',
+            'sep_token': '", "',
+            'sep_act': '"], ["',
+            'end_act': '"]], "',
+            'start_text': 'text": "',
+            'end_json': '}',
+            'end_json_2': '"}'
         }
-        if self.use_sentiment:
-            self.format_tokens['start_json'] = '{"sentiment": "'
-            self.format_tokens['start_emotion'] = 'emotion": "'
 
         if only_action:
             self.format_tokens['end_act'] = '"]]}'
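With `start_json` reduced to `'{"'`, the first generated key can be either `sentiment` or `emotion`, so a single token map now serves all four span orderings. For the default emotion-first ordering, the format tokens and the model's own tokens assemble into a JSON string of roughly this shape, which `_parse_output` then loads (values are illustrative):

```python
# start_json + start_emotion + <emotion> + sep_token + start_act +
# <action tokens> + end_act + start_text + <text> + end_json_2
decoded = (
    '{"emotion": "neutral", '
    '"action": [["inform", "hotel", "area", "north"]], '
    '"text": "I am looking for a hotel in the north."}'
)
```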