From f48c40ebe36f90c145b901fa36f904728d45479b Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Sun, 5 Feb 2023 23:02:43 +0100
Subject: [PATCH] emoTUS: add golden-emotion/golden-action evaluation modes and turn-level sentiment error bands

---
 convlab/policy/emoTUS/analysis.py     |  70 ++++++----
 convlab/policy/emoTUS/emoTUS.py       |   5 +-
 convlab/policy/emoTUS/emotion_eval.py |   2 +-
 convlab/policy/emoTUS/evaluate.py     | 183 +++++++++++++-------------
 4 files changed, 144 insertions(+), 116 deletions(-)
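
Usage sketch for the new evaluation flags (patch notes only, not part of the
commit): it assumes an Evaluator instance constructed as in main() of
evaluate.py; `evaluator` and `test_file` below are placeholder names, not
identifiers introduced by this patch.

    # golden emotion + action -> utterance only; evaluation() skips the
    # semantic (action) scoring in this mode
    gen_file = evaluator.nlg_evaluation(input_file=test_file, golden_action=True)
    evaluator.evaluation(gen_file, golden_action=True)

    # golden emotion -> action + utterance; evaluation() skips the emotion
    # and sentiment scoring in this mode
    gen_file = evaluator.nlg_evaluation(input_file=test_file, golden_emotion=True)
    evaluator.evaluation(gen_file, golden_emotion=True)

    # free run: emotion, action, and utterance are all generated and scored
    gen_file = evaluator.nlg_evaluation(input_file=test_file)
    evaluator.evaluation(gen_file)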

diff --git a/convlab/policy/emoTUS/analysis.py b/convlab/policy/emoTUS/analysis.py
index 8ac60c74..75b55eb2 100644
--- a/convlab/policy/emoTUS/analysis.py
+++ b/convlab/policy/emoTUS/analysis.py
@@ -1,8 +1,12 @@
+import json
+import os
 from argparse import ArgumentParser
+
+import matplotlib.pyplot as plt
 import numpy as np
-import json
 import pandas as pd
-import matplotlib.pyplot as plt
+
+result_dir = "convlab/policy/emoTUS/result"
 
 
 def arg_parser():
@@ -54,51 +58,69 @@ def get_turn_emotion(conversation):
                     insert_turn(turn_info[metric], turn, emotion)
                 else:
                     insert_turn(turn_info[f"Not {metric}"], turn, emotion)
-    data = {'x': [t for t in range(turn)], 'all_positive': [
-    ], 'all_negative': [], 'all_mean': []}
+    print("MAX_TURN", max_turn)
+    data = {'x': [t for t in range(max_turn)],
+            'all_positive': [],
+            'all_negative': [],
+            'all_mean': [],
+            'all_std': []}
+
     for metric in ["Complete", "Success", "Success strict"]:
         data[f"{metric}_positive"] = []
         data[f"{metric}_negative"] = []
         data[f"{metric}_mean"] = []
+        data[f"{metric}_std"] = []
         data[f"Not {metric}_positive"] = []
         data[f"Not {metric}_negative"] = []
         data[f"Not {metric}_mean"] = []
+        data[f"Not {metric}_std"] = []
 
-    for t in range(turn):
+    for t in range(max_turn):
-        pos, neg, mean = turn_score(turn_info["all"][t])
+        pos, neg, mean, std = turn_score(turn_info["all"][t])
         data[f"all_positive"].append(pos)
         data[f"all_negative"].append(neg)
         data[f"all_mean"].append(mean)
+        data[f"all_std"].append(std)
         for raw_metric in ["Complete", "Success", "Success strict"]:
             for metric in [raw_metric, f"Not {raw_metric}"]:
                 if t not in turn_info[metric]:
                     data[f"{metric}_positive"].append(0)
                     data[f"{metric}_negative"].append(0)
                     data[f"{metric}_mean"].append(0)
+                    data[f"{metric}_std"].append(0)
                 else:
-                    pos, neg, mean = turn_score(turn_info[metric][t])
+                    pos, neg, mean, std = turn_score(turn_info[metric][t])
                     data[f"{metric}_positive"].append(pos)
                     data[f"{metric}_negative"].append(neg)
                     data[f"{metric}_mean"].append(mean)
+                    data[f"{metric}_std"].append(std)
+    for x in data:
+        data[x] = np.array(data[x])
 
     fig, ax = plt.subplots()
-    ax.plot(data['x'], data["Complete_mean"],
-            'o--', color='C0', label="Complete")
-    ax.fill_between(data['x'], data["Complete_positive"],
-                    data["Complete_negative"], color='C0', alpha=0.2)
-    ax.plot(data['x'], data["Not Complete_mean"],
-            'o--', color='C1', label="Not Complete")
-    ax.fill_between(data['x'], data["Not Complete_positive"],
-                    data["Not Complete_negative"], color='C1', alpha=0.2)
-    ax.plot(data['x'], data["all_mean"], 'o--', color='C2',
-            label="All")
-    ax.fill_between(data['x'], data["all_positive"],
-                    data["all_negative"], color='C2', alpha=0.2)
+    p = {"Complete": {"color": "C0", "label": "Success"},
+         "Not Complete": {"color": "C1", "label": "Fail"},
+         "all": {"color": "C2", "label": "all"}}
+    for name, para in p.items():
+
+        ax.plot(data['x'],
+                data[f"{name}_mean"],
+                'o--',
+                color=para["color"],
+                label=para["label"])
+        ax.fill_between(data['x'],
+                        data[f"{name}_mean"]+data[f"{name}_std"],
+                        data[f"{name}_mean"]-data[f"{name}_std"],
+                        color=para["color"], alpha=0.2)
+
     ax.legend()
     ax.set_xlabel("turn")
     ax.set_ylabel("Sentiment")
+    ax.set_xticks([t for t in range(0, max_turn, 2)])
+    plt.grid(axis='x', color='0.95')
+    plt.grid(axis='y', color='0.95')
     # plt.show()
-    plt.savefig("convlab/policy/emoTUS/fig.png")
+    plt.savefig(os.path.join(result_dir, "turn2emotion.png"))
 
 
 def turn_score(score_list):
@@ -110,7 +132,7 @@ def turn_score(score_list):
             positive += 1
         if s < 0:
             negative += -1
-    return positive/count, negative/count, np.mean(score_list)
+    return positive/count, negative/count, np.mean(score_list), np.std(score_list, ddof=1)/np.sqrt(len(score_list))  # last value: standard error of the mean
 
 
 def insert_turn(turn_info, turn, emotion):
@@ -253,13 +275,15 @@ def dict2csv(data):
         r[act] = temp
     dataframe = pd.DataFrame.from_dict(
         r, orient='index', columns=[emo for emo in emotion]+["count"])
-    dataframe.to_csv(open("convlab/policy/emoTUS/act2emotion.csv", 'w'))
+    dataframe.to_csv(open(os.path.join(result_dir, "act2emotion.csv"), 'w'))
 
 
 def main():
     args = arg_parser()
     result = {}
-    conversation = json.load(open(args.file))
+    if not os.path.exists(result_dir):
+        os.makedirs(result_dir)
+    conversation = json.load(open(args.file))["conversation"]
     basic_info = basic_analysis(conversation)
     result["basic_info"] = basic_info
     print(basic_info)
@@ -267,7 +291,7 @@ def main():
     print(advance_info)
     result["advance_info"] = advance_info
     json.dump(result, open(
-        "convlab/policy/emoTUS/conversation_result.json", 'w'), indent=2)
+        os.path.join("conversation_result.json"), 'w'), indent=2)
     dict2csv(advance_info)
     get_turn_emotion(conversation)
 
diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py
index 2fb0790f..8d10e765 100644
--- a/convlab/policy/emoTUS/emoTUS.py
+++ b/convlab/policy/emoTUS/emoTUS.py
@@ -235,12 +235,11 @@ class UserActionPolicy(GenTUSUserActionPolicy):
 
         return text
 
-    def generate_from_emotion(self, raw_inputs,  emotion=None, mode="max", allow_general_intent=True):
+    def generate_from_emotion(self, raw_inputs, emotion=None, mode="max", allow_general_intent=True):
         self.kg.parse_input(raw_inputs)
         model_input = self.vector.encode(raw_inputs, self.max_in_len)
         responses = {}
         if emotion:
-            print("if emotion")
             emotion_list = [emotion]
         else:
             emotion_list = self.emotion_list
@@ -285,6 +284,8 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
         pos = self._update_seq([0], 0)
         pos = self._update_seq(self.token_map.get_id('start_json'), pos)
+        pos = self._update_seq(
+            self.token_map.get_id('start_emotion'), pos)
         pos = self._update_seq(self.kg._get_token_id(emotion), pos)
         pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
         pos = self._update_seq(self.token_map.get_id('start_act'), pos)
diff --git a/convlab/policy/emoTUS/emotion_eval.py b/convlab/policy/emoTUS/emotion_eval.py
index da1479f1..be0c383c 100644
--- a/convlab/policy/emoTUS/emotion_eval.py
+++ b/convlab/policy/emoTUS/emotion_eval.py
@@ -41,7 +41,7 @@ class Evaluator:
 
         self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
         self.use_sentiment = kwargs.get("use_sentiment", False)
-        self.add_persona = kwargs.get("add_persona", False)
+        self.add_persona = kwargs.get("add_persona", True)
         self.emotion_mid = kwargs.get("emotion_mid", False)
         weight = kwargs.get("weight", None)
         self.sample = kwargs.get("sample", False)
diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py
index 358a4828..60a15e05 100644
--- a/convlab/policy/emoTUS/evaluate.py
+++ b/convlab/policy/emoTUS/evaluate.py
@@ -29,8 +29,10 @@ def arg_parser():
     parser.add_argument("--generated-file", type=str, help="the generated results",
                         default="")
     parser.add_argument("--dataset", default="multiwoz")
-    parser.add_argument("--do-golden-nlg", action="store_true",
-                        help="do golden nlg generation")
+    parser.add_argument("--golden-emotion", action="store_true",
+                        help="golden emotion -> action + utt")
+    parser.add_argument("--golden-action", action="store_true",
+                        help="golden emotion + action -> utt")
     parser.add_argument("--use-sentiment", action="store_true")
     parser.add_argument("--emotion-mid", action="store_true")
     parser.add_argument("--weight", type=float, default=None)
@@ -83,23 +85,29 @@ class Evaluator:
         for x in self.r:
             self.r[x].append(temp[x])
 
-    def generate_results(self, f_eval, golden=False):
+    def generate_results(self, f_eval, golden_emotion=False, golden_action=False):
         emotion_mode = "normal"
         in_file = json.load(open(f_eval))
-
+        mode = "max"
+        if self.sample:
+            mode = "sample"
         for dialog in tqdm(in_file['dialog']):
             inputs = dialog["in"]
             labels = self.usr._parse_output(dialog["out"])
 
-            if golden:
+            if golden_action:
                 usr_act = labels["action"]
+                usr_emo = labels["emotion"]
                 usr_utt = self.usr.generate_text_from_give_semantic(
                     inputs, labels["action"], labels["emotion"])
-
+            elif golden_emotion:
+                usr_emo = labels["emotion"]
+                output = self.usr.generate_from_emotion(
+                    inputs, emotion=usr_emo, mode=mode)
+                output = self.usr._parse_output(output[usr_emo])
+                usr_act = self.usr._remove_illegal_action(output["action"])
+                usr_utt = output["text"]
             else:
-                mode = "max"
-                if self.sample:
-                    mode = "sample"
                 output = self.usr._parse_output(
                     self.usr._generate_action(inputs, mode=mode, emotion_mode=emotion_mode))
                 usr_emo = output["emotion"]
@@ -139,10 +147,10 @@ class Evaluator:
             result.append(temp)
         return result
 
-    def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
+    def nlg_evaluation(self, input_file=None, generated_file=None, golden_emotion=False, golden_action=False):
         if input_file:
             print("Force generation")
-            self.generate_results(input_file, golden)
+            self.generate_results(input_file, golden_emotion, golden_action)
 
         elif generated_file:
             self.read_generated_result(generated_file)
@@ -152,34 +160,40 @@ class Evaluator:
         if self.sample:
             mode = "sample"
 
-        nlg_eval = {
-            "golden": golden,
-            "mode": mode,
-            "metrics": {},
-            "dialog": self._transform_result()
-        }
+        nlg_eval = {}
+        if golden_action:
+            nlg_eval["golden"] = "golden_action"
+        elif golden_emotion:
+            nlg_eval["golden"] = "golden_emotion"
+        else:
+            nlg_eval["golden"] = False
 
-        if golden:
-            print("Calculate BLEU")
-            bleu_metric = load_metric("sacrebleu")
-            labels = [[utt] for utt in self.r["golden_utts"]]
+        nlg_eval["mode"] = mode
+        nlg_eval["metrics"] = {}
+        nlg_eval["dialog"] = self._transform_result()
 
-            bleu_score = bleu_metric.compute(predictions=self.r["gen_utts"],
-                                             references=labels,
-                                             force=True)
-            print("bleu_metric", bleu_score)
-            nlg_eval["metrics"]["bleu"] = bleu_score
+        # if golden_action:
+        print("Calculate BLEU")
+        bleu_metric = load_metric("sacrebleu")
+        labels = [[utt] for utt in self.r["golden_utts"]]
 
-        else:
-            print("Calculate SER")
-            missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
-                self.r["gen_acts"], self.r["gen_utts"])
+        bleu_score = bleu_metric.compute(predictions=self.r["gen_utts"],
+                                         references=labels,
+                                         force=True)
+        print("bleu_metric", bleu_score)
+        nlg_eval["metrics"]["bleu"] = bleu_score
 
-            print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
-                "genTUSNLG", missing, total, hallucinate, missing/total))
-            nlg_eval["metrics"]["SER"] = missing/total
+        # else:
+        print("Calculate SER")
+        missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
+            self.r["gen_acts"], self.r["gen_utts"])
 
-            # TODO emotion metric
+        print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
+            "EmoUSNLG", missing, total, hallucinate, missing/total))
+        print(nlg_eval["metrics"])
+        nlg_eval["metrics"]["SER"] = missing/total
+
+        # TODO emotion metric
 
         dir_name = self.model_checkpoint
         json.dump(nlg_eval,
@@ -188,42 +202,23 @@ class Evaluator:
                   indent=2)
         return os.path.join(dir_name, f"{self.time}-nlg_eval.json")
 
-    def evaluation(self, input_file=None, generated_file=None):
+    def evaluation(self, generated_file, golden_emotion=False, golden_action=False):
         # TODO add emotion
-        force_prediction = True
-        if generated_file:
-            print("---> use generated file")
-            gen_file = json.load(open(generated_file))
-            force_prediction = False
-            if gen_file["golden"]:
-                force_prediction = True
-            self.read_generated_result(generated_file)
+        gen_file = json.load(open(generated_file))
+        self.read_generated_result(generated_file)
 
-        if force_prediction:
-            in_file = json.load(open(input_file))
-            dialog_result = []
+        if golden_action:
+            print("golden_action, skip semantic evaluation")
+            return
+
+        elif golden_emotion:
+            print("golden_emotion, skip emotion evaluation")
             gen_acts, golden_acts = [], []
-            # scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
-            for dialog in tqdm(in_file['dialog']):
-                inputs = dialog["in"]
-                labels = self.usr._parse_output(dialog["out"])
-                ans_action = self.usr._remove_illegal_action(labels["action"])
-                preds = self.usr._generate_action(inputs)
-                preds = self.usr._parse_output(preds)
-                usr_action = self.usr._remove_illegal_action(preds["action"])
-
-                gen_acts.append(usr_action)
-                golden_acts.append(ans_action)
-
-                d = {"input": inputs,
-                     "golden_acts": ans_action,
-                     "gen_acts": usr_action}
-                if "text" in preds:
-                    d["golden_utts"] = labels["text"]
-                    d["gen_utts"] = preds["text"]
-                    # print("pred text", preds["text"])
-
-                dialog_result.append(d)
+            for dialog in gen_file['dialog']:
+                gen_acts.append(dialog["gen_acts"])
+                golden_acts.append(dialog["golden_acts"])
+            dialog_result = gen_file['dialog']
+
         else:
             gen_acts, golden_acts = [], []
             gen_emotions, golden_emotions = [], []
@@ -246,27 +241,33 @@ class Evaluator:
             result[metric] = sum(scores[metric])/len(scores[metric])
             print(f"{metric}: {result[metric]}")
 
-        emo_score = emotion_score(
-            golden_emotions,
-            gen_emotions,
-            self.model_checkpoint,
-            time=self.time,
-            no_neutral=False)
-        if self.use_sentiment:
-            sent_score = sentiment_score(
-                self.r["golden_sentiment"],
-                self.r["gen_sentiment"],
-                self.model_checkpoint,
-                time=self.time)
-        else:
-            # transfer emotions to sentiment if the model do not generate sentiment
-            golden_sentiment = [self.emo2sent[emo] for emo in golden_emotions]
-            gen_sentiment = [self.emo2sent[emo] for emo in gen_emotions]
-            sent_score = sentiment_score(
-                golden_sentiment,
-                gen_sentiment,
+        if not golden_emotion:
+            emo_score = emotion_score(
+                golden_emotions,
+                gen_emotions,
                 self.model_checkpoint,
-                time=self.time)
+                time=self.time,
+                no_neutral=False)
+            result["emotion"] = {"macro_f1": emo_score["macro_f1"],
+                                 "sep_f1": emo_score["sep_f1"]}
+            if self.use_sentiment:
+                sent_score = sentiment_score(
+                    self.r["golden_sentiment"],
+                    self.r["gen_sentiment"],
+                    self.model_checkpoint,
+                    time=self.time)
+            else:
+                # transfer emotions to sentiment if the model does not generate sentiment
+                golden_sentiment = [self.emo2sent[emo]
+                                    for emo in golden_emotions]
+                gen_sentiment = [self.emo2sent[emo] for emo in gen_emotions]
+                sent_score = sentiment_score(
+                    golden_sentiment,
+                    gen_sentiment,
+                    self.model_checkpoint,
+                    time=self.time)
+            result["sentiment"] = {"macro_f1": sent_score["macro_f1"],
+                                   "sep_f1": sent_score["sep_f1"]}
 
         # for metric in emo_score:
         #     result[metric] = emo_score[metric]
@@ -356,11 +357,13 @@ def main():
         else:
             nlg_result = eval.nlg_evaluation(input_file=args.input_file,
                                              generated_file=args.generated_file,
-                                             golden=args.do_golden_nlg)
+                                             golden_emotion=args.golden_emotion,
+                                             golden_action=args.golden_action)
 
             generated_file = nlg_result
-        eval.evaluation(args.input_file,
-                        generated_file)
+        eval.evaluation(generated_file,
+                        golden_emotion=args.golden_emotion,
+                        golden_action=args.golden_action)
 
 
 if __name__ == '__main__':
-- 
GitLab