diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py
index 4e7f038d29425f21299c2a3e9ab8158401cdc3a8..b5ba1cfad2f149bfc48c4333d0f9a83c95c3c30d 100644
--- a/convlab/policy/emoTUS/self_bleu.py
+++ b/convlab/policy/emoTUS/self_bleu.py
@@ -9,6 +9,7 @@ def arg_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--file", type=str)
     parser.add_argument("--fast-bleu", action="store_true")
+    parser.add_argument("--uss", action="store_true")
     return parser.parse_args()
 
 
@@ -17,13 +18,17 @@ def read_file(file_name):
     return nlg_candidates
 
 
-def get_sent(candidates, bleu_mode="torch"):
+def get_sent(candidates, bleu_mode="torch", uss=False):
     if bleu_mode == "torch":
+        if uss:
+            return [x["preds"] for x in candidates]
         if "log" in candidates:
             return [x["gen_utts"] for x in candidates["log"]]
         else:
             return [x["gen_utts"] for x in candidates["dialog"]]
     else:
+        if uss:
+            return [x["preds"].split() for x in candidates]
         if "log" in candidates:
             return [x["gen_utts"].split() for x in candidates["log"]]
         else:
@@ -41,20 +46,22 @@ def SelfBLEU(sentences):
     return sum(result)/len(result)
 
 
-def calculate(candidates, bleu_mode="torch"):
-    sentences = get_sent(candidates, bleu_mode)
+def calculate(candidates, bleu_mode="torch", uss=False):
+    sentences = get_sent(candidates, bleu_mode, uss)
     if bleu_mode == "torch":
         x = SelfBLEU(sentences)
     else:
         bleu = fast_bleu.SelfBLEU(sentences)
         x = bleu.get_score()
     # x = bleu.get_score()
+    # print(x)
     print(sum(x[4])/len(x[4]))
 
+
 if __name__ == "__main__":
     args = arg_parser()
     if args.fast_bleu:
         import fast_bleu
-        calculate(read_file(args.file), "fast-bleu")
+        calculate(read_file(args.file), "fast-bleu", uss=args.uss)
     else:
-        calculate(read_file(args.file))
+        calculate(read_file(args.file), uss=args.uss)
diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py
index 644b61434b722d22b1f82489e85a0c471701afc6..2e8a07d5a20bf50bde5dd2f103385a490de3f4a7 100644
--- a/convlab/policy/ussT5/emowoz_evaluate.py
+++ b/convlab/policy/ussT5/emowoz_evaluate.py
@@ -2,6 +2,7 @@ import json
 import os
 from argparse import ArgumentParser
 from datetime import datetime
+import numpy as np
 
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -12,6 +13,8 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
 
 from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
 from convlab.policy.ussT5.evaluate import tri_convert
+from datasets import load_metric
+
 
 def arg_parser():
     parser = ArgumentParser()
@@ -104,29 +107,30 @@ def generate_result(model_checkpoint, data, stop=-1):
 
 
 def read_result(result):
-    preds = []
-    label = []
+    d = {}
+    for d_name in ["satisfaction score", "utterance generation", "action prediction"]:
+        d[d_name] = {"preds": [], "label": []}
     for r in result:
-        if "satisfaction score" in r["input_text"]:
-            preds.append(r["preds"])
-            label.append(r["label"])
-    return preds, label
+        for d_name in ["satisfaction score", "utterance generation", "action prediction"]:
+            if d_name in r["input_text"]:
+                d[d_name]["preds"].append(r["preds"])
+                d[d_name]["label"].append(r["label"])
+    return d
 
 
-def main():
-    args = arg_parser()
-    if args.gen_file:
-        preds, label = read_result(json.load(open(args.gen_file)))
-    else:
-        data = build_data(load_experiment_dataset(args.data)["test"])
-        results = generate_result(args.model, data, args.stop)
-        preds, label = read_result(results)
+def satisfaction(model, d):
+    # satisfaction
     all_sentiment = ["Neutral", "Negative", "Positive"]
     print(all_sentiment)
-    tri_f1 = metrics.f1_score(label, preds, average="macro")
-    sep_f1 = metrics.f1_score(label, preds, average=None, labels=all_sentiment)
+    tri_f1 = metrics.f1_score(
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], average="macro")
+    sep_f1 = metrics.f1_score(
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], average=None, labels=all_sentiment)
     cm = metrics.confusion_matrix(
-        label, preds, normalize="true", labels=all_sentiment)
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], normalize="true", labels=all_sentiment)
     disp = metrics.ConfusionMatrixDisplay(
         confusion_matrix=cm,
         display_labels=all_sentiment)
@@ -136,7 +140,66 @@ def main():
          "cm": [list(c) for c in list(cm)]}
     print(r)
     time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
-    plt.savefig(os.path.join(args.model, f"{time}-emowoz.png"))
+    plt.savefig(os.path.join(model, f"{time}-emowoz.png"))
+
+
+def utterance(model, d):
+    bleu_metric = load_metric("sacrebleu")
+    labels = [[utt] for utt in d["utterance generation"]["label"]]
+
+    bleu_score = bleu_metric.compute(
+        predictions=d["utterance generation"]["preds"],
+        references=labels,
+        force=True)
+    print(f"{model} bleu_score", bleu_score)
+
+
+def action(model, d):
+    score = {}
+    for preds, label in zip(d["action prediction"]["preds"], d["action prediction"]["label"]):
+        s = f1_score(preds, label)
+        for n, v in s.items():
+            if n not in score:
+                score[n] = []
+            score[n].append(v)
+    print(f"{model} action")
+    for n, v in score.items():
+        print(n, np.mean(v))
+
+
+def f1_score(prediction, label):
+    score = {}
+    tp = 0
+    pre = prediction.split(',')
+    lab = label.split(',')
+    for p in pre:
+        if p in lab:
+            tp += 1
+    score["precision"] = tp/len(pre)
+    score["recall"] = tp/len(lab)
+    score["F1"] = 0
+    if score["precision"]+score["recall"] > 0:
+        score["F1"] = 2*score["precision"]*score["recall"] / \
+            (score["precision"]+score["recall"])
+    if pre == lab:
+        score["acc"] = 1
+    else:
+        score["acc"] = 0
+    return score
+
+
+def main():
+    args = arg_parser()
+    if args.gen_file:
+        d = read_result(json.load(open(args.gen_file)))
+    else:
+        data = build_data(load_experiment_dataset(args.data)["test"])
+        results = generate_result(args.model, data, args.stop)
+        d = read_result(results)
+    model = args.model
+    satisfaction(model, d)
+    utterance(model, d)
+    action(model, d)
 
 
 if __name__ == "__main__":
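
Note (not part of the patch): a minimal sanity-check sketch of how the regrouped read_result() output and the comma-separated f1_score() helper behave. The example records and file contents below are illustrative, and the snippet assumes the patched convlab.policy.ussT5.emowoz_evaluate module is importable.

    from convlab.policy.ussT5.emowoz_evaluate import f1_score, read_result

    # Illustrative generation records: read_result() only checks that a task name
    # occurs somewhere in "input_text" and then buckets each "preds"/"label" pair per task.
    example_results = [
        {"input_text": "satisfaction score: ...", "preds": "Neutral", "label": "Neutral"},
        {"input_text": "utterance generation: ...", "preds": "Sure, which area?", "label": "Which area do you prefer?"},
        {"input_text": "action prediction: ...", "preds": "inform,request", "label": "inform,bye"},
    ]
    d = read_result(example_results)
    # d["action prediction"] == {"preds": ["inform,request"], "label": ["inform,bye"]}

    # One shared action out of two on each side: precision = recall = F1 = 0.5, no exact match.
    print(f1_score("inform,request", "inform,bye"))
    # {'precision': 0.5, 'recall': 0.5, 'F1': 0.5, 'acc': 0}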