From 14f05d86276ed36134f564c0d18a50cbb26a8986 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Wed, 1 Feb 2023 17:35:34 +0100
Subject: [PATCH] emotion_eval: add ROUGE and length metrics, score Neutral; remove debug print

---
 convlab/policy/emoTUS/emoTUS.py       |  1 -
 convlab/policy/emoTUS/emotion_eval.py | 33 +++++++++++++++++++++------
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py
index fca0f3a9..2fb0790f 100644
--- a/convlab/policy/emoTUS/emoTUS.py
+++ b/convlab/policy/emoTUS/emoTUS.py
@@ -91,7 +91,6 @@ class UserActionPolicy(GenTUSUserActionPolicy):
                 raw_output = self._generate_action(
                     raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
         output = self._parse_output(raw_output)
-        print(output)
         self.semantic_action = self._remove_illegal_action(output["action"])
         self.utterance = output["text"]
         self.emotion = output["emotion"]
diff --git a/convlab/policy/emoTUS/emotion_eval.py b/convlab/policy/emoTUS/emotion_eval.py
index 0d78f16a..da1479f1 100644
--- a/convlab/policy/emoTUS/emotion_eval.py
+++ b/convlab/policy/emoTUS/emotion_eval.py
@@ -184,8 +184,8 @@ class Evaluator:
 
         scores = {}
         for emotion in self.emotion_list:
-            if emotion == "Neutral":
-                continue
+            # if emotion == "Neutral":
+            #     continue
             scores[emotion] = {"precision": [],
                                "recall": [], "f1": [], "turn_acc": []}
             for gen_act, golden_act in zip(r[f"{emotion}_acts"], r["Neutral_acts"]):
@@ -195,16 +195,23 @@ class Evaluator:
 
         result = {}
         for emotion in self.emotion_list:
-            if emotion == "Neutral":
-                continue
+            # if emotion == "Neutral":
+            #     continue
             result[emotion] = {}
+            for metric in scores[emotion]:
+                result[emotion][metric] = sum(
+                    scores[emotion][metric])/len(scores[emotion][metric])
             result[emotion]["bleu"] = bleu(golden_utts=r["Neutral_utts"],
                                            gen_utts=r[f"{emotion}_utts"])
             result[emotion]["SER"] = SER(gen_utts=r[f"{emotion}_utts"],
                                          gen_acts=r[f"{emotion}_acts"])
-            for metric in scores[emotion]:
-                result[emotion][metric] = sum(
-                    scores[emotion][metric])/len(scores[emotion][metric])
+
+            result[emotion]["len"] = avg_len(gen_utts=r[f"{emotion}_utts"])
+
+            rouge_score = rouge(golden_utts=r["Neutral_utts"],
+                                gen_utts=r[f"{emotion}_utts"])
+            for metric, score in rouge_score.items():
+                result[emotion][metric] = score.mid.fmeasure
 
             print("emotion:", emotion)
             for metric in result[emotion]:
@@ -221,6 +228,11 @@ class Evaluator:
             self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'), indent=2)
 
 
+def avg_len(gen_utts):
+    n = [len(s.split()) for s in gen_utts]
+    return sum(n)/len(n)
+
+
 def bleu(golden_utts, gen_utts):
     bleu_metric = load_metric("sacrebleu")
     labels = [[utt] for utt in golden_utts]
@@ -231,6 +243,13 @@ def bleu(golden_utts, gen_utts):
     return bleu_score["score"]
 
 
+def rouge(golden_utts, gen_utts):
+    rouge_metric = load_metric("rouge")
+    rouge_score = rouge_metric.compute(predictions=gen_utts,
+                                       references=golden_utts)
+    return rouge_score
+
+
 def SER(gen_utts, gen_acts):
     missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
         gen_acts, gen_utts)
-- 
GitLab
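
For reference, a minimal usage sketch of the two helpers this patch adds (avg_len and rouge). It assumes the HuggingFace datasets load_metric API already used by emotion_eval.py (newer installs expose the same metrics via the evaluate package), that the rouge_score backend is installed, and that compute() returns aggregated AggregateScore objects whose mid.fmeasure is kept, as in the patch; the example utterances are made up.

    from datasets import load_metric


    def avg_len(gen_utts):
        # average whitespace-token length of the generated utterances
        n = [len(s.split()) for s in gen_utts]
        return sum(n) / len(n)


    def rouge(golden_utts, gen_utts):
        # aggregated ROUGE scores (rouge1, rouge2, rougeL, rougeLsum)
        rouge_metric = load_metric("rouge")
        return rouge_metric.compute(predictions=gen_utts, references=golden_utts)


    neutral_utts = ["i am looking for a cheap restaurant in the centre"]
    emotion_utts = ["i really need a cheap restaurant in the centre, please"]

    result = {"len": avg_len(emotion_utts)}
    for metric, score in rouge(neutral_utts, emotion_utts).items():
        # each value is an AggregateScore; keep the mid F-measure, as in the patch
        result[metric] = score.mid.fmeasure
    print(result)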