From 85fce3cf10c319e7211dd226484678e59f4bd73a Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Mon, 16 Jan 2023 01:30:01 +0100 Subject: [PATCH] wip --- convlab/policy/emoTUS/evaluate.py | 2 +- convlab/policy/ussT5/emowoz_evaluate.py | 118 +++++++++++++++++++ convlab/policy/{uss-t5 => ussT5}/evaluate.py | 2 + convlab/policy/{uss-t5 => ussT5}/predict.py | 0 convlab/policy/{uss-t5 => ussT5}/train.py | 0 5 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 convlab/policy/ussT5/emowoz_evaluate.py rename convlab/policy/{uss-t5 => ussT5}/evaluate.py (97%) rename convlab/policy/{uss-t5 => ussT5}/predict.py (100%) rename convlab/policy/{uss-t5 => ussT5}/train.py (100%) diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index 9b233e4f..0edc6169 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -175,7 +175,7 @@ class Evaluator: # TODO add emotion force_prediction = True if generated_file: - print("use generated file") + print("---> use generated file") gen_file = json.load(open(generated_file)) force_prediction = False if gen_file["golden"]: diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py new file mode 100644 index 00000000..fcd42099 --- /dev/null +++ b/convlab/policy/ussT5/emowoz_evaluate.py @@ -0,0 +1,118 @@ +import json +import os +from argparse import ArgumentParser +from datetime import datetime + +import matplotlib.pyplot as plt +import pandas as pd +from sklearn import metrics +from tqdm import tqdm +from transformers import T5ForConditionalGeneration, T5Tokenizer + +from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset +from convlab.policy.ussT5.evaluate import tri_convert + + +def arg_parser(): + parser = ArgumentParser() + parser.add_argument("--model", type=str, default="", + help="model name") + parser.add_argument("--data", default="emowoz+dialmage", type=str) + 
parser.add_argument("--gen-file", type=str)
+    parser.add_argument("--stop", default=-1, type=int)
+    return parser.parse_args()
+
+
+def build_data(data):
+    sentiments = {}
+    for sentiment, index in json.load(open("convlab/policy/emoTUS/sentiment.json")).items():
+        sentiments[int(index)] = sentiment
+    out = {"input_text": [], "target_text": []}
+    prefix = "satisfaction score: "
+    for d in data:
+        utt = ""
+        turn_len = len(d["turns"])
+        for index, turn in enumerate(d["turns"]):
+            if turn["speaker"] == "user":
+                if index == turn_len - 1:
+                    continue
+                if index == 0:
+                    utt = prefix + turn["utterance"]
+                else:
+                    utt += ' ' + turn["utterance"]
+            else:
+                if index == 0:
+                    print("this should no happen (index == 0)")
+                    utt = prefix + turn["utterance"]
+                if index == turn_len - 1:
+                    print("this should no happen (index == turn_len - 1)")
+                    continue
+
+                utt += ' ' + turn["utterance"]
+
+                out["input_text"].append(utt)
+                out["target_text"].append(
+                    sentiments[d["turns"][index+1]["emotion"][-1]["sentiment"]])
+    return out
+
+
+def generate_result(model_checkpoint, data, stop=-1):
+    tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
+    model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
+    data = pd.read_csv(data, index_col=False).astype(str)
+    results = []
+    i = 0
+    for input_text, target_text in tqdm(zip(data["input_text"], data["target_text"]), ascii=True):
+        if stop > 0 and i > stop:
+            break
+        if "satisfaction score" in input_text:
+            i += 1
+            inputs = tokenizer([input_text], return_tensors="pt", padding=True)
+            output = model.generate(input_ids=inputs["input_ids"],
+                                    attention_mask=inputs["attention_mask"],
+                                    do_sample=False)[0]
+            output = tokenizer.decode(output, skip_special_tokens=True)
+            if not output:
+                print(output)
+                output = "illegal"
+
+            results.append({"input_text": input_text,
+                            "preds": tri_convert(output),
+                            "label": target_text})
+    json.dump(results, open(os.path.join(
+        model_checkpoint, "emowoz_result.json"), 'w'))
+    return results
+
+
+
+def read_result(result):
+    preds = []
+    label = []
+    for r in result:
+        preds.append(r["preds"])
+        label.append(r["label"])
+    return preds, label
+
+
+def main():
+    args = arg_parser()
+    if args.gen_file:
+        preds, label = read_result(json.load(open(args.gen_file)))
+    else:
+        results = generate_result(args.model, args.data, args.stop)
+        preds, label = read_result(results)
+    all_sentiment = ["Neutral", "Negative", "Positive"]
+    print(all_sentiment)
+    tri_f1 = metrics.f1_score(label, preds, average="macro")
+    sep_f1 = metrics.f1_score(label, preds, average=None, labels=all_sentiment)
+    cm = metrics.confusion_matrix(
+        label, preds, normalize="true", labels=all_sentiment)
+    disp = metrics.ConfusionMatrixDisplay(
+        confusion_matrix=cm,
+        display_labels=all_sentiment)
+    disp.plot()
+    r = {"tri_f1": float(tri_f1),
+         "sep_f1": list(sep_f1),
+         "cm": [list(c) for c in list(cm)]}
+    print(r)
+    time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
+    plt.savefig(os.path.join(args.model, f"{time}-emowoz.png"))
diff --git a/convlab/policy/uss-t5/evaluate.py b/convlab/policy/ussT5/evaluate.py
similarity index 97%
rename from convlab/policy/uss-t5/evaluate.py
rename to convlab/policy/ussT5/evaluate.py
index 5475cbef..67e166ee 100644
--- a/convlab/policy/uss-t5/evaluate.py
+++ b/convlab/policy/ussT5/evaluate.py
@@ -84,6 +84,8 @@ def generate_result(model_checkpoint, data):
         results.append({"input_text": input_text,
                         "preds": output,
                         "label": target_text})
+        json.dump(results, open(os.path.join(
+            model_checkpoint, "uss_result.json"), 'w'))
     return results
diff --git a/convlab/policy/uss-t5/predict.py b/convlab/policy/ussT5/predict.py
similarity index 100%
rename from convlab/policy/uss-t5/predict.py
rename to convlab/policy/ussT5/predict.py
diff --git a/convlab/policy/uss-t5/train.py b/convlab/policy/ussT5/train.py
similarity index 100%
rename from convlab/policy/uss-t5/train.py
rename to convlab/policy/ussT5/train.py
-- 
GitLab