From 85fce3cf10c319e7211dd226484678e59f4bd73a Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Mon, 16 Jan 2023 01:30:01 +0100
Subject: [PATCH] wip

---
 convlab/policy/emoTUS/evaluate.py            |   2 +-
 convlab/policy/ussT5/emowoz_evaluate.py      | 118 +++++++++++++++++++
 convlab/policy/{uss-t5 => ussT5}/evaluate.py |   2 +
 convlab/policy/{uss-t5 => ussT5}/predict.py  |   0
 convlab/policy/{uss-t5 => ussT5}/train.py    |   0
 5 files changed, 121 insertions(+), 1 deletion(-)
 create mode 100644 convlab/policy/ussT5/emowoz_evaluate.py
 rename convlab/policy/{uss-t5 => ussT5}/evaluate.py (97%)
 rename convlab/policy/{uss-t5 => ussT5}/predict.py (100%)
 rename convlab/policy/{uss-t5 => ussT5}/train.py (100%)

diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py
index 9b233e4f..0edc6169 100644
--- a/convlab/policy/emoTUS/evaluate.py
+++ b/convlab/policy/emoTUS/evaluate.py
@@ -175,7 +175,7 @@ class Evaluator:
         # TODO add emotion
         force_prediction = True
         if generated_file:
-            print("use generated file")
+            print("---> use generated file")
             gen_file = json.load(open(generated_file))
             force_prediction = False
             if gen_file["golden"]:
diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py
new file mode 100644
index 00000000..fcd42099
--- /dev/null
+++ b/convlab/policy/ussT5/emowoz_evaluate.py
@@ -0,0 +1,118 @@
+import json
+import os
+from argparse import ArgumentParser
+from datetime import datetime
+
+import matplotlib.pyplot as plt
+import pandas as pd
+from sklearn import metrics
+from tqdm import tqdm
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
+from convlab.policy.ussT5.evaluate import tri_convert
+
+
def arg_parser():
    """Parse command-line options for the EmoWOZ sentiment evaluation script.

    Returns:
        argparse.Namespace with:
            model: path to the T5 checkpoint directory (default "").
            data: dataset name (default "emowoz+dialmage").
            gen_file: optional path to a previously generated result JSON.
            stop: max number of examples to score; -1 means no limit.
    """
    parser = ArgumentParser()
    parser.add_argument("--model", type=str, default="",
                        help="model name")
    parser.add_argument("--data", default="emowoz+dialmage", type=str)
    parser.add_argument("--gen-file", type=str)
    # type=int is required: without it a CLI-supplied value is a str and
    # the `stop > 0` comparison in generate_result raises TypeError.
    parser.add_argument("--stop", type=int, default=-1)
    return parser.parse_args()
+
+
def build_data(data):
    """Turn EmoWOZ-style dialogues into T5 "satisfaction score" examples.

    For every system turn (except a dialogue-final one) the accumulated
    dialogue history becomes one input, and the sentiment of the *next*
    user turn becomes the target label.

    Args:
        data: iterable of dialogues, each a dict with a "turns" list;
              each turn has "speaker", "utterance", and user turns carry
              "emotion" annotations ending in a {"sentiment": index} dict.

    Returns:
        dict with parallel "input_text" and "target_text" lists.
    """
    # index -> sentiment name; the JSON file maps name -> index.
    # (Was a list written to by index, which raises IndexError.)
    sentiments = {}
    for sentiment, index in json.load(open("convlab/policy/emoTUS/sentiment.json")).items():
        sentiments[int(index)] = sentiment
    # Use a fresh name for the output: the original rebound `data` here,
    # so the loop below iterated the output dict's keys, not the dialogues.
    examples = {"input_text": [], "target_text": []}
    prefix = "satisfaction score: "
    for d in data:
        utt = ""
        turn_len = len(d["turns"])
        for index, turn in enumerate(d["turns"]):
            if turn["speaker"] == "user":
                # The dialogue-final user turn has no following system turn
                # to score, so it is dropped.
                if index == turn_len - 1:
                    continue
                if index == 0:
                    utt = prefix + turn["utterance"]
                else:
                    utt += ' ' + turn["utterance"]
            else:
                if index == 0:
                    print("this should not happen (index == 0)")
                    utt = prefix + turn["utterance"]
                if index == turn_len - 1:
                    print("this should not happen (index == turn_len - 1)")
                    continue

                utt += ' ' + turn["utterance"]

                # Label = sentiment of the user turn that follows this
                # system turn.
                examples["input_text"].append(utt)
                examples["target_text"].append(
                    sentiments[d["turns"][index + 1]["emotion"][-1]["sentiment"]])
    return examples
+
+
def generate_result(model_checkpoint, data, stop=-1):
    """Score a CSV of satisfaction-prompt examples with a T5 checkpoint.

    Args:
        model_checkpoint: directory containing the fine-tuned T5 model;
            results are also written here as "emowoz_result.json".
        data: path to a CSV with "input_text" and "target_text" columns.
        stop: stop after this many scored rows; <= 0 means score all.

    Returns:
        list of {"input_text", "preds", "label"} dicts, where "preds" is
        the tri-level sentiment produced by tri_convert.
    """
    tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
    frame = pd.read_csv(data, index_col=False).astype(str)
    results = []
    count = 0
    for input_text, target_text in tqdm(zip(frame["input_text"], frame["target_text"]), ascii=True):
        if stop > 0 and count > stop:
            break
        # Only rows carrying the satisfaction prompt are scored.
        if "satisfaction score" not in input_text:
            continue
        count += 1
        inputs = tokenizer([input_text], return_tensors="pt", padding=True)
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 do_sample=False)
        # Decode the whole generated batch (one sequence). The original
        # called batch_decode on a single 1-D tensor (generate(...)[0]),
        # which decodes every token id separately and flagged nearly all
        # multi-token predictions as "illegal".
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if len(decoded) > 1:
            print(decoded)
            prediction = "illegal"
        else:
            prediction = decoded[0]

        results.append({"input_text": input_text,
                        "preds": tri_convert(prediction),
                        "label": target_text})
    json.dump(results, open(os.path.join(
        model_checkpoint, "emowoz_result.json"), 'w'))
    return results
+
+
def read_result(result):
    """Split prediction records into parallel prediction and label lists.

    Args:
        result: list of dicts each holding "preds" and "label" keys.

    Returns:
        (preds, label): two parallel lists of sentiment strings.
    """
    preds = []
    label = []
    for r in result:
        # The original indexed with the list variables (r[preds]), which
        # raises TypeError (unhashable); string keys are intended.
        preds.append(r["preds"])
        label.append(r["label"])
    return preds, label
+
+
def main():
    """Evaluate sentiment predictions and save a confusion-matrix plot.

    Either reads a previously generated result file (--gen-file) or runs
    the model over the dataset, then prints macro / per-class F1 and
    writes a timestamped confusion-matrix PNG into the model directory.
    """
    args = arg_parser()
    if args.gen_file:
        preds, label = read_result(json.load(open(args.gen_file)))
    else:
        results = generate_result(args.model, args.data, args.stop)
        preds, label = read_result(results)
    all_sentiment = ["Neutral", "Negative", "Positive"]
    print(all_sentiment)
    tri_f1 = metrics.f1_score(label, preds, average="macro")
    sep_f1 = metrics.f1_score(label, preds, average=None, labels=all_sentiment)
    cm = metrics.confusion_matrix(
        label, preds, normalize="true", labels=all_sentiment)
    # display_labels must match the 3 sentiment classes; the original
    # passed ['1'..'5'] (5 labels for a 3x3 matrix), which fails in plot().
    disp = metrics.ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=all_sentiment)
    disp.plot()
    r = {"tri_f1": float(tri_f1),
         "sep_f1": list(sep_f1),
         "cm": [list(c) for c in list(cm)]}
    print(r)
    time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
    plt.savefig(os.path.join(args.model, f"{time}-emowoz.png"))


if __name__ == "__main__":
    # The original script defined main() but never invoked it, so running
    # the file did nothing.
    main()
diff --git a/convlab/policy/uss-t5/evaluate.py b/convlab/policy/ussT5/evaluate.py
similarity index 97%
rename from convlab/policy/uss-t5/evaluate.py
rename to convlab/policy/ussT5/evaluate.py
index 5475cbef..67e166ee 100644
--- a/convlab/policy/uss-t5/evaluate.py
+++ b/convlab/policy/ussT5/evaluate.py
@@ -84,6 +84,8 @@ def generate_result(model_checkpoint, data):
             results.append({"input_text": input_text,
                             "preds": output,
                             "label": target_text})
+    json.dump(results, open(os.path.join(
+        model_checkpoint, "uss_result.json"), 'w'))
     return results
 
 
diff --git a/convlab/policy/uss-t5/predict.py b/convlab/policy/ussT5/predict.py
similarity index 100%
rename from convlab/policy/uss-t5/predict.py
rename to convlab/policy/ussT5/predict.py
diff --git a/convlab/policy/uss-t5/train.py b/convlab/policy/ussT5/train.py
similarity index 100%
rename from convlab/policy/uss-t5/train.py
rename to convlab/policy/ussT5/train.py
-- 
GitLab