From 269b24cb00129afef8c65a63c778d0046c689348 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Thu, 12 Jan 2023 01:31:59 +0100
Subject: [PATCH] emoTUS: add sentiment prediction support (WIP)

---
 convlab/policy/emoTUS/emoTUS.py    |  45 +++++++--
 convlab/policy/emoTUS/evaluate.py  | 157 +++++++++++++++++------------
 convlab/policy/emoTUS/token_map.py |   7 +-
 3 files changed, 134 insertions(+), 75 deletions(-)

diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py
index 9ecbedd5..9f97eb0e 100644
--- a/convlab/policy/emoTUS/emoTUS.py
+++ b/convlab/policy/emoTUS/emoTUS.py
@@ -25,6 +25,8 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         for emotion, index in data_emotion.items():
             self.emotion_list[index] = emotion
 
+        self.use_sentiment = kwargs.get("use_sentiment", False)
+
         self.init_session()
 
     def predict(self, sys_act, mode="max", allow_general_intent=True, emotion=None):
@@ -46,11 +48,18 @@ class UserActionPolicy(GenTUSUserActionPolicy):
             else:
                 history = self.usr_acts[-1*self.max_history:]
 
-        # TODO add user info? impolite?
-        inputs = json.dumps({"system": sys_act,
-                             "goal": self.goal.get_goal_list(),
-                             "history": history,
-                             "turn": str(int(self.time_step/2))})
+        # TODO add user info? impolite? -> check self.use_sentiment
+        if self.use_sentiment:
+            # TODO how to get event and user politeness?
+            inputs = json.dumps({"system": sys_act,
+                                 "goal": self.goal.get_goal_list(),
+                                 "history": history,
+                                 "turn": str(int(self.time_step/2))})
+        else:
+            inputs = json.dumps({"system": sys_act,
+                                 "goal": self.goal.get_goal_list(),
+                                 "history": history,
+                                 "turn": str(int(self.time_step/2))})
         with torch.no_grad():
             if emotion == "all":
                 raw_output = self.generate_from_emotion(
@@ -101,6 +110,9 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         in_str = in_str.replace('<s>', '').replace(
             '<\\s>', '').replace('o"clock', "o'clock")
         action = {"emotion": "Neutral", "action": [], "text": ""}
+        if self.use_sentiment:
+            action["sentiment"] = "Neutral"
+
         try:
             action = json.loads(in_str)
         except:
@@ -115,9 +127,14 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
         pos = self._update_seq([0], 0)
         pos = self._update_seq(self.token_map.get_id('start_json'), pos)
-        emotion = self._get_emotion(
-            model_input, self.seq[:1, :pos], mode, emotion_mode)
-        pos = self._update_seq(emotion["token_id"], pos)
+        if self.use_sentiment:
+            sentiment = self._get_sentiment(
+                model_input, self.seq[:1, :pos], mode)
+            pos = self._update_seq(sentiment["token_id"], pos)
+        else:
+            emotion = self._get_emotion(
+                model_input, self.seq[:1, :pos], mode, emotion_mode)
+            pos = self._update_seq(emotion["token_id"], pos)
         pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
         pos = self._update_seq(self.token_map.get_id('start_act'), pos)
 
@@ -136,6 +153,13 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         if self.only_action:
             return self.vector.decode(self.seq[0, :pos])
 
+        if self.use_sentiment:
+            pos = self._update_seq(self.token_map.get_id('start_emotion'), pos)
+            emotion = self._get_emotion(
+                model_input, self.seq[:1, :pos], mode, emotion_mode)
+            pos = self._update_seq(emotion["token_id"], pos)
+            pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
+
         pos = self._update_seq(self.token_map.get_id("start_text"), pos)
         text = self._get_text(model_input, pos)
 
@@ -216,6 +240,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         raw_output = self._get_text(model_input, pos)
         return self._parse_output(raw_output)["text"]
 
+    def _get_sentiment(self, model_input, generated_so_far, mode="max"):
+        next_token_logits = self.model.get_next_token_logits(
+            model_input, generated_so_far)
+        return self.kg.get_sentiment(next_token_logits, mode)
+
     def _get_emotion(self, model_input, generated_so_far, mode="max", emotion_mode="normal"):
         next_token_logits = self.model.get_next_token_logits(
             model_input, generated_so_far)
diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py
index 3b01e589..04baca8d 100644
--- a/convlab/policy/emoTUS/evaluate.py
+++ b/convlab/policy/emoTUS/evaluate.py
@@ -42,117 +42,111 @@ def arg_parser():
 
 
 class Evaluator:
-    def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False):
+    def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False, use_sentiment=False):
         self.dataset = dataset
         self.model_checkpoint = model_checkpoint
         self.model_weight = model_weight
         self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
-        # if model_weight:
-        #     self.usr_policy = UserPolicy(
-        #         self.model_checkpoint, only_action=only_action)
-        #     self.usr_policy.load(model_weight)
-        #     self.usr = self.usr_policy.usr
-        # else:
+        self.use_sentiment = use_sentiment
+
         self.usr = UserActionPolicy(
-            model_checkpoint, only_action=only_action, dataset=self.dataset)
+            model_checkpoint, only_action=only_action, dataset=self.dataset, use_sentiment=use_sentiment)
         self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
 
+        self.r = {"input": [],
+                  "golden_acts": [],
+                  "golden_utts": [],
+                  "golden_emotion": [],
+                  "gen_acts": [],
+                  "gen_utts": [],
+                  "gen_emotion": []}
+        if use_sentiment:
+            self.r["golden_sentiment"] = []
+            self.r["gen_sentiment"] = []
+
+    def _append_result(self, temp):
+        for x in self.r:
+            self.r[x].append(temp[x])
+
     def generate_results(self, f_eval, golden=False, no_neutral=False):
         emotion_mode = "max"
         if no_neutral:
             emotion_mode = "no_neutral"
         in_file = json.load(open(f_eval))
-        r = {
-            "input": [],
-            "golden_acts": [],
-            "golden_utts": [],
-            "golden_emotion": [],
-            "gen_acts": [],
-            "gen_utts": [],
-            "gen_emotion": []
-        }
+
         for dialog in tqdm(in_file['dialog']):
             inputs = dialog["in"]
             labels = self.usr._parse_output(dialog["out"])
-            if no_neutral:
-                if labels["emotion"].lower() == "neutral":
-                    print("skip")
-                    continue
-            print("do", labels["emotion"])
+            if no_neutral and labels["emotion"].lower() == "neutral":
+                continue
+
             if golden:
                 usr_act = labels["action"]
                 usr_utt = self.usr.generate_text_from_give_semantic(
                     inputs, labels["action"], labels["emotion"])
 
             else:
-                
                 output = self.usr._parse_output(
                     self.usr._generate_action(inputs, emotion_mode=emotion_mode))
                 usr_emo = output["emotion"]
                 usr_act = self.usr._remove_illegal_action(output["action"])
                 usr_utt = output["text"]
-            r["input"].append(inputs)
-            r["golden_acts"].append(labels["action"])
-            r["golden_utts"].append(labels["text"])
-            r["golden_emotion"].append(labels["emotion"])
 
-            r["gen_acts"].append(usr_act)
-            r["gen_utts"].append(usr_utt)
-            r["gen_emotion"].append(usr_emo)
+            temp = {}
+            temp["input"] = inputs
+            temp["golden_acts"] = labels["action"]
+            temp["golden_utts"] = labels["text"]
+            temp["golden_emotion"] = labels["emotion"]
+
+            temp["gen_acts"] = usr_act
+            temp["gen_utts"] = usr_utt
+            temp["gen_emotion"] = labels["emotion"] if golden else usr_emo
 
-        return r
+            if self.use_sentiment:
+                temp["golden_sentiment"] = labels["sentiment"]
+                temp["gen_sentiment"] = labels["sentiment"] if golden else output["sentiment"]
+
+            self._append_result(temp)
 
     def read_generated_result(self, f_eval):
         in_file = json.load(open(f_eval))
-        r = {
-            "input": [],
-            "golden_acts": [],
-            "golden_utts": [],
-            "golden_emotion": [],
-            "gen_acts": [],
-            "gen_utts": [],
-            "gen_emotion": []
-        }
+
         for dialog in tqdm(in_file['dialog']):
             for x in dialog:
-                r[x].append(dialog[x])
-
-        return r
+                self.r[x].append(dialog[x])
+
+    def _transform_result(self):
+        index = [x for x in self.r]
+        result = []
+        for i in range(len(self.r[index[0]])):
+            temp = {}
+            for x in index:
+                temp[x] = self.r[x][i]
+            result.append(temp)
+        return result
 
     def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False):
         if input_file:
             print("Force generation")
-            gen_r = self.generate_results(input_file, golden, no_neutral)
+            self.generate_results(input_file, golden, no_neutral)
 
         elif generated_file:
-            gen_r = self.read_generated_result(generated_file)
+            self.read_generated_result(generated_file)
         else:
             print("You must specify the input_file or the generated_file")
 
         nlg_eval = {
             "golden": golden,
             "metrics": {},
-            "dialog": []
+            "dialog": self._transform_result()
         }
-        for input, golden_act, golden_utt, golden_emo, gen_act, gen_utt, gen_emo in zip(
-                gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"], gen_r["golden_emotion"],
-                gen_r["gen_acts"], gen_r["gen_utts"], gen_r["gen_emotion"]):
-            nlg_eval["dialog"].append({
-                "input": input,
-                "golden_acts": golden_act,
-                "golden_utts": golden_utt,
-                "golden_emotion": golden_emo,
-                "gen_acts": gen_act,
-                "gen_utts": gen_utt,
-                "gen_emotion": gen_emo
-            })
 
         if golden:
             print("Calculate BLEU")
             bleu_metric = load_metric("sacrebleu")
-            labels = [[utt] for utt in gen_r["golden_utts"]]
+            labels = [[utt] for utt in self.r["golden_utts"]]
 
-            bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
+            bleu_score = bleu_metric.compute(predictions=self.r["gen_utts"],
                                              references=labels,
                                              force=True)
             print("bleu_metric", bleu_score)
@@ -161,7 +155,7 @@ class Evaluator:
         else:
             print("Calculate SER")
             missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
-                gen_r["gen_acts"], gen_r["gen_utts"])
+                self.r["gen_acts"], self.r["gen_utts"])
 
             print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
                 "genTUSNLG", missing, total, hallucinate, missing/total))
@@ -171,7 +165,8 @@ class Evaluator:
 
         dir_name = self.model_checkpoint
         json.dump(nlg_eval,
-                  open(os.path.join(dir_name, f"{self.time}-nlg_eval.json"), 'w'),
+                  open(os.path.join(
+                      dir_name, f"{self.time}-nlg_eval.json"), 'w'),
                   indent=2)
         return os.path.join(dir_name, f"{self.time}-nlg_eval.json")
 
@@ -232,8 +227,19 @@ class Evaluator:
             result[metric] = sum(scores[metric])/len(scores[metric])
             print(f"{metric}: {result[metric]}")
         # TODO no neutral
-        emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint, 
-        time=self.time, no_neutral=True)
+        emo_score = emotion_score(
+            golden_emotions,
+            gen_emotions,
+            self.model_checkpoint,
+            time=self.time,
+            no_neutral=False)
+        if self.use_sentiment:
+            sent_score = sentiment_score(
+                [d["golden_sentiment"] for d in gen_file['dialog']],
+                [d["gen_sentiment"] for d in gen_file['dialog']],
+                self.model_checkpoint,
+                time=self.time)
+
         # for metric in emo_score:
         #     result[metric] = emo_score[metric]
         #     print(f"{metric}: {result[metric]}")
@@ -254,7 +260,8 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
     macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro")
     sep_f1 = metrics.f1_score(
         golden_emotions, gen_emotions, average=None, labels=labels)
-    cm = metrics.confusion_matrix(golden_emotions, gen_emotions, normalize="true", labels=labels)
+    cm = metrics.confusion_matrix(
+        golden_emotions, gen_emotions, normalize="true", labels=labels)
     disp = metrics.ConfusionMatrixDisplay(
         confusion_matrix=cm, display_labels=labels)
     disp.plot()
@@ -265,6 +272,26 @@ def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutra
     return r
 
 
+def sentiment_score(golden_sentiment, gen_sentiment, dirname=".", time=""):
+    labels = ["Neutral", "Negative", "Positive"]
+
+    print(labels)
+    macro_f1 = metrics.f1_score(
+        golden_sentiment, gen_sentiment, average="macro")
+    sep_f1 = metrics.f1_score(
+        golden_sentiment, gen_sentiment, average=None, labels=labels)
+    cm = metrics.confusion_matrix(
+        golden_sentiment, gen_sentiment, normalize="true", labels=labels)
+    disp = metrics.ConfusionMatrixDisplay(
+        confusion_matrix=cm, display_labels=labels)
+    disp.plot()
+    plt.savefig(os.path.join(dirname, f"{time}-sentiment.png"))
+    r = {"macro_f1": float(macro_f1), "sep_f1": list(
+        sep_f1), "cm": [list(c) for c in list(cm)]}
+    print(r)
+    return r
+
+
 def f1_measure(preds, labels):
     tp = 0
     score = {"precision": 0, "recall": 0, "f1": 0, "turn_acc": 0}
diff --git a/convlab/policy/emoTUS/token_map.py b/convlab/policy/emoTUS/token_map.py
index face1e35..b8b9f885 100644
--- a/convlab/policy/emoTUS/token_map.py
+++ b/convlab/policy/emoTUS/token_map.py
@@ -2,7 +2,8 @@ import json
 
 
 class tokenMap:
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, use_sentiment=False):
+        self.use_sentiment = use_sentiment
         self.tokenizer = tokenizer
         self.token_name = {}
         self.hash_map = {}
@@ -10,7 +11,6 @@ class tokenMap:
         self.default()
 
     def default(self, only_action=False):
-        # TODO
         self.format_tokens = {
             'start_json': '{"emotion": "',   # 49643, 10845, 7862, 646
             'start_act': 'action": [["',              # 49329
@@ -21,6 +20,10 @@ class tokenMap:
             'end_json': '}',                 # 24303
             'end_json_2': '"}'                 # 48805
         }
+        if self.use_sentiment:
+            self.format_tokens['start_json'] = '{"sentiment": "'
+            self.format_tokens['start_emotion'] = 'emotion": "'
+
         if only_action:
             self.format_tokens['end_act'] = '"]]}'
         for token_name in self.format_tokens:
-- 
GitLab