From 269b24cb00129afef8c65a63c778d0046c689348 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Thu, 12 Jan 2023 01:31:59 +0100
Subject: [PATCH] wip

---
 convlab/policy/emoTUS/emoTUS.py    |  45 +++++++--
 convlab/policy/emoTUS/evaluate.py  | 158 +++++++++++++++++------------
 convlab/policy/emoTUS/token_map.py |   8 ++-
 3 files changed, 136 insertions(+), 75 deletions(-)

diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py
index 9ecbedd5..9f97eb0e 100644
--- a/convlab/policy/emoTUS/emoTUS.py
+++ b/convlab/policy/emoTUS/emoTUS.py
@@ -25,6 +25,8 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         for emotion, index in data_emotion.items():
             self.emotion_list[index] = emotion
 
+        self.use_sentiment = kwargs.get("use_sentiment", False)
+
         self.init_session()
 
     def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None):
@@ -46,11 +48,18 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         else:
             history = self.usr_acts[-1*self.max_history:]
 
-        # TODO add user info? impolite?
-        inputs = json.dumps({"system": sys_act,
-                             "goal": self.goal.get_goal_list(),
-                             "history": history,
-                             "turn": str(int(self.time_step/2))})
+        # TODO add user info? impolite? -> check self.use_sentiment
+        if self.use_sentiment:
+            # TODO how to get event and user politeness?
+            inputs = json.dumps({"system": sys_act,
+                                 "goal": self.goal.get_goal_list(),
+                                 "history": history,
+                                 "turn": str(int(self.time_step/2))})
+        else:
+            inputs = json.dumps({"system": sys_act,
+                                 "goal": self.goal.get_goal_list(),
+                                 "history": history,
+                                 "turn": str(int(self.time_step/2))})
         with torch.no_grad():
             if emotion == "all":
                 raw_output = self.generate_from_emotion(
@@ -101,6 +110,9 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         in_str = in_str.replace('<s>', '').replace(
             '<\\s>', '').replace('o"clock', "o'clock")
         action = {"emotion": "Neutral", "action": [], "text": ""}
+        if self.use_sentiment:
+            action["sentiment"] = "Neutral"
+
         try:
             action = json.loads(in_str)
         except:
@@ -115,9 +127,14 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
         pos = self._update_seq([0], 0)
         pos = self._update_seq(self.token_map.get_id('start_json'), pos)
-        emotion = self._get_emotion(
-            model_input, self.seq[:1, :pos], mode, emotion_mode)
-        pos = self._update_seq(emotion["token_id"], pos)
+        if self.use_sentiment:
+            sentiment = self._get_sentiment(
+                model_input, self.seq[:1, :pos], mode)
+            pos = self._update_seq(sentiment["token_id"], pos)
+        else:
+            emotion = self._get_emotion(
+                model_input, self.seq[:1, :pos], mode, emotion_mode)
+            pos = self._update_seq(emotion["token_id"], pos)
 
         pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
         pos = self._update_seq(self.token_map.get_id('start_act'), pos)
@@ -136,6 +153,13 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         if self.only_action:
             return self.vector.decode(self.seq[0, :pos])
 
+        if self.use_sentiment:
+            pos = self._update_seq(self.token_map.get_id('start_emotion'), pos)
+            emotion = self._get_emotion(
+                model_input, self.seq[:1, :pos], mode, emotion_mode)
+            pos = self._update_seq(emotion["token_id"], pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
         pos = self._update_seq(self.token_map.get_id("start_text"), pos)
         text = self._get_text(model_input, pos)
 
@@ -216,6 +240,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         raw_output = self._get_text(model_input, pos)
         return self._parse_output(raw_output)["text"]
 
+    def _get_sentiment(self, model_input, generated_so_far, mode="max"):
mode="max"): + next_token_logits = self.model.get_next_token_logits( + model_input, generated_so_far) + return self.kg.get_sentiment(next_token_logits, mode) + def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"): next_token_logits = self.model.get_next_token_logits( model_input, generated_so_far) diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index 3b01e589..04baca8d 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -42,117 +42,111 @@ def arg_parser(): class Evaluator: - def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False): + def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False, use_sentiment=False): self.dataset = dataset self.model_checkpoint = model_checkpoint self.model_weight = model_weight self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}" - # if model_weight: - # self.usr_policy = UserPolicy( - # self.model_checkpoint, only_action=only_action) - # self.usr_policy.load(model_weight) - # self.usr = self.usr_policy.usr - # else: + self.use_sentiment = use_sentiment + self.usr = UserActionPolicy( - model_checkpoint, only_action=only_action, dataset=self.dataset) + model_checkpoint, only_action=only_action, dataset=self.dataset, use_sentiment=use_sentiment) self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) + self.r = {"input": [], + "golden_acts": [], + "golden_utts": [], + "golden_emotion": [], + "gen_acts": [], + "gen_utts": [], + "gen_emotion": []} + if use_sentiment: + self.r["golden_sentiment"] = [] + self.r["gen_sentiment"] = [] + + def _append_result(self, temp): + for x in self.r: + self.r[x].append(temp[x]) + def generate_results(self, f_eval, golden=False, no_neutral=False): emotion_mode = "max" if no_neutral: emotion_mode = "no_neutral" in_file = json.load(open(f_eval)) - r = { - "input": [], - "golden_acts": [], - "golden_utts": [], - "golden_emotion": [], - "gen_acts": [], - "gen_utts": [], - "gen_emotion": [] - } + for dialog in tqdm(in_file['dialog']): inputs = dialog["in"] labels = self.usr._parse_output(dialog["out"]) - if no_neutral: - if labels["emotion"].lower() == "neutral": - print("skip") - continue - print("do", labels["emotion"]) + if no_neutral and labels["emotion"].lower() == "neutral": + continue + if golden: usr_act = labels["action"] usr_utt = self.usr.generate_text_from_give_semantic( inputs, labels["action"], labels["emotion"]) else: - output = self.usr._parse_output( self.usr._generate_action(inputs, emotion_mode=emotion_mode)) usr_emo = output["emotion"] usr_act = self.usr._remove_illegal_action(output["action"]) usr_utt = output["text"] - r["input"].append(inputs) - r["golden_acts"].append(labels["action"]) - r["golden_utts"].append(labels["text"]) - r["golden_emotion"].append(labels["emotion"]) - r["gen_acts"].append(usr_act) - r["gen_utts"].append(usr_utt) - r["gen_emotion"].append(usr_emo) + temp = {} + temp["input"] = inputs + temp["golden_acts"] = labels["action"] + temp["golden_utts"] = labels["text"] + temp["golden_emotion"] = labels["emotion"] + + temp["gen_acts"] = usr_act + temp["gen_utts"] = usr_utt + temp["gen_emotion"] = usr_emo - return r + if self.use_sentiment: + temp["golden_sentiment"] = labels["sentiment"] + temp["gen_sentiment"] = output["sentiment"] + + self._append_result(temp) def read_generated_result(self, f_eval): in_file = json.load(open(f_eval)) - r = { - "input": [], - "golden_acts": [], - "golden_utts": [], - "golden_emotion": 
-            "gen_acts": [],
-            "gen_utts": [],
-            "gen_emotion": []
-        }
+
         for dialog in tqdm(in_file['dialog']):
             for x in dialog:
-                r[x].append(dialog[x])
-
-        return r
+                self.r[x].append(dialog[x])
+
+    def _transform_result(self):
+        index = [x for x in self.r]
+        result = []
+        for i in range(len(self.r[index[0]])):
+            temp = {}
+            for x in index:
+                temp[x] = self.r[x][i]
+            result.append(temp)
+        return result
 
     def nlg_evaluation(self, input_file=None, generated_file=None, golden=False,
                        no_neutral=False):
         if input_file:
             print("Force generation")
-            gen_r = self.generate_results(input_file, golden, no_neutral)
+            self.generate_results(input_file, golden, no_neutral)
 
         elif generated_file:
-            gen_r = self.read_generated_result(generated_file)
+            self.read_generated_result(generated_file)
         else:
             print("You must specify the input_file or the generated_file")
 
         nlg_eval = {
             "golden": golden,
             "metrics": {},
-            "dialog": []
+            "dialog": self._transform_result()
         }
 
-        for input, golden_act, golden_utt, golden_emo, gen_act, gen_utt, gen_emo in zip(
-                gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["golden_emotion"],
-                gen_r["gen_acts"], gen_r["gen_utts"], gen_r["gen_emotion"]):
-            nlg_eval["dialog"].append({
-                "input": input,
-                "golden_acts": golden_act,
-                "golden_utts": golden_utt,
-                "golden_emotion": golden_emo,
-                "gen_acts": gen_act,
-                "gen_utts": gen_utt,
-                "gen_emotion": gen_emo
-            })
-
         if golden:
             print("Calculate BLEU")
             bleu_metric = load_metric("sacrebleu")
-            labels = [[utt] for utt in gen_r["golden_utts"]]
+            labels = [[utt] for utt in self.r["golden_utts"]]
 
-            bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
+            bleu_score = bleu_metric.compute(predictions=self.r["gen_utts"],
                                              references=labels,
                                              force=True)
             print("bleu_metric", bleu_score)
@@ -161,7 +156,7 @@
         else:
             print("Calculate SER")
             missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
-                gen_r["gen_acts"], gen_r["gen_utts"])
+                self.r["gen_acts"], self.r["gen_utts"])
 
             print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
                 "genTUSNLG", missing, total, hallucinate, missing/total))
@@ -171,7 +166,8 @@
             dir_name = self.model_checkpoint
 
         json.dump(nlg_eval,
-                  open(os.path.join(dir_name, f"{self.time}-nlg_eval.json"), 'w'),
+                  open(os.path.join(
+                      dir_name, f"{self.time}-nlg_eval.json"), 'w'),
                   indent=2)
 
         return os.path.join(dir_name, f"{self.time}-nlg_eval.json")
@@ -232,8 +228,19 @@
             result[metric] = sum(scores[metric])/len(scores[metric])
             print(f"{metric}: {result[metric]}")
         # TODO no neutral
-        emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint,
-                                  time=self.time, no_neutral=True)
+        emo_score = emotion_score(
+            golden_emotions,
+            gen_emotions,
+            self.model_checkpoint,
+            time=self.time,
+            no_neutral=False)
+        if self.use_sentiment:
+            sent_score = sentiment_score(
+                [d["golden_sentiment"] for d in gen_file['dialog']],
+                [d["gen_sentiment"] for d in gen_file['dialog']],
+                self.model_checkpoint,
+                time=self.time)
+
         # for metric in emo_score:
         #     result[metric] = emo_score[metric]
         #     print(f"{metric}: {result[metric]}")
@@ -254,7 +261,8 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
     macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro")
     sep_f1 = metrics.f1_score(
         golden_emotions, gen_emotions, average=None, labels=labels)
-    cm = metrics.confusion_matrix(golden_emotions, gen_emotions, normalize="true", labels=labels)
+    cm = metrics.confusion_matrix(
+        golden_emotions, gen_emotions, normalize="true", labels=labels)
     disp = metrics.ConfusionMatrixDisplay(
         confusion_matrix=cm, display_labels=labels)
     disp.plot()
@@ -265,6 +273,26 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
     return r
 
 
+def sentiment_score(golden_sentiment, gen_sentiment, dirname=".", time=""):
+    labels = ["Neutral", "Negative", "Positive"]
+
+    print(labels)
+    macro_f1 = metrics.f1_score(
+        golden_sentiment, gen_sentiment, average="macro")
+    sep_f1 = metrics.f1_score(
+        golden_sentiment, gen_sentiment, average=None, labels=labels)
+    cm = metrics.confusion_matrix(
+        golden_sentiment, gen_sentiment, normalize="true", labels=labels)
+    disp = metrics.ConfusionMatrixDisplay(
+        confusion_matrix=cm, display_labels=labels)
+    disp.plot()
+    plt.savefig(os.path.join(dirname, f"{time}-sentiment.png"))
+    r = {"macro_f1": float(macro_f1), "sep_f1": list(
+        sep_f1), "cm": [list(c) for c in list(cm)]}
+    print(r)
+    return r
+
+
 def f1_measure(preds, labels):
     tp = 0
     score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
diff --git a/convlab/policy/emoTUS/token_map.py b/convlab/policy/emoTUS/token_map.py
index face1e35..b8b9f885 100644
--- a/convlab/policy/emoTUS/token_map.py
+++ b/convlab/policy/emoTUS/token_map.py
@@ -2,7 +2,8 @@ import json
 
 
 class tokenMap:
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, use_sentiment=False):
         self.tokenizer = tokenizer
         self.token_name = {}
         self.hash_map = {}
+        self.use_sentiment = use_sentiment  # read by default() below
@@ -10,7 +11,6 @@ class tokenMap:
         self.default()
 
     def default(self, only_action=False):
-        # TODO
         self.format_tokens = {
             'start_json': '{"emotion": "',  # 49643, 10845, 7862, 646
             'start_act': 'action": [["',  # 49329
@@ -21,6 +21,10 @@
             'end_json': '}',  # 24303
             'end_json_2': '"}'  # 48805
         }
+        if self.use_sentiment:
+            self.format_tokens['start_json'] = '{"sentiment": "'
+            self.format_tokens['start_emotion'] = 'emotion": "'
+
         if only_action:
             self.format_tokens['end_act'] = '"]]}'
         for token_name in self.format_tokens:
--
GitLab
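
Note on the new output layout: with use_sentiment enabled, the token_map.py change above makes the model open the JSON with a sentiment field and emit emotion only after the action (via the new start_emotion token), so _parse_output must handle two shapes. A minimal illustration; the slot values and utterances are invented, only the field order comes from the patch:

    import json

    # Default EmoTUS layout: emotion opens the object.
    base = ('{"emotion": "Neutral", '
            '"action": [["inform", "hotel", "area", "east"]], '
            '"text": "I am looking for a hotel in the east."}')

    # use_sentiment=True layout: sentiment opens the object and
    # emotion is generated between the action and the text.
    with_sentiment = ('{"sentiment": "Neutral", '
                      '"action": [["inform", "hotel", "area", "east"]], '
                      '"emotion": "Neutral", '
                      '"text": "I am looking for a hotel in the east."}')

    for raw in (base, with_sentiment):
        parsed = json.loads(raw)
        print(parsed.get("sentiment", "<absent>"), parsed["emotion"], parsed["action"])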
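
For completeness, a rough usage sketch of the extended Evaluator; the checkpoint directory, test file, and dataset name are placeholders, not values taken from this patch:

    from convlab.policy.emoTUS.evaluate import Evaluator

    evaluator = Evaluator(
        model_checkpoint="path/to/emous_checkpoint",  # hypothetical path
        dataset="emowoz",                             # assumed dataset name
        use_sentiment=True)

    # generate_results() now accumulates into evaluator.r instead of returning,
    # and nlg_evaluation() writes <timestamp>-nlg_eval.json next to the checkpoint.
    result_file = evaluator.nlg_evaluation(input_file="path/to/test_set.json")
    print(result_file)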