Commit 15b002b0 authored by Hsien-Chin Lin

wip

parent 2c5ddaec
@@ -9,6 +9,7 @@ def arg_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--file", type=str)
     parser.add_argument("--fast-bleu", action="store_true")
+    parser.add_argument("--uss", action="store_true")
     return parser.parse_args()
@@ -17,13 +18,17 @@ def read_file(file_name):
     return nlg_candidates
-def get_sent(candidates, bleu_mode="torch"):
+def get_sent(candidates, bleu_mode="torch", uss=False):
     if bleu_mode == "torch":
+        if uss:
+            return [x["preds"] for x in candidates]
         if "log" in candidates:
             return [x["gen_utts"] for x in candidates["log"]]
         else:
             return [x["gen_utts"] for x in candidates["dialog"]]
     else:
+        if uss:
+            return [x["preds"].split() for x in candidates]
         if "log" in candidates:
             return [x["gen_utts"].split() for x in candidates["log"]]
         else:
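For orientation, get_sent handles two result layouts: with --uss, a flat list of records whose generated text sits under "preds"; otherwise, a dialogue file whose utterances sit under "gen_utts" inside a "log" (or "dialog") list. A small sketch with made-up values (field names taken from the code above, everything else hypothetical):

# Hypothetical inputs illustrating the two layouts read by get_sent().
uss_candidates = [                      # --uss style: flat list with "preds"
    {"preds": "i am looking for a cheap hotel"},
    {"preds": "what area would you like ?"},
]
log_candidates = {                      # dialogue-log style: "log" (or "dialog")
    "log": [
        {"gen_utts": "there are 3 hotels in the centre ."},
        {"gen_utts": "the address is 124 tenison road ."},
    ]
}
# get_sent(uss_candidates, uss=True)               -> list of strings
# get_sent(log_candidates)                          -> strings from "log"
# get_sent(uss_candidates, "fast-bleu", uss=True)   -> lists of tokens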
@@ -41,20 +46,22 @@ def SelfBLEU(sentences):
     return sum(result)/len(result)
-def calculate(candidates, bleu_mode="torch"):
-    sentences = get_sent(candidates, bleu_mode)
+def calculate(candidates, bleu_mode="torch", uss=False):
+    sentences = get_sent(candidates, bleu_mode, uss)
     if bleu_mode == "torch":
         x = SelfBLEU(sentences)
     else:
         bleu = fast_bleu.SelfBLEU(sentences)
         x = bleu.get_score()
+        # x = bleu.get_score()
+        # print(x)
         print(sum(x[4])/len(x[4]))
 if __name__ == "__main__":
     args = arg_parser()
     if args.fast_bleu:
         import fast_bleu
-        calculate(read_file(args.file), "fast-bleu")
+        calculate(read_file(args.file), "fast-bleu", uss=args.uss)
     else:
-        calculate(read_file(args.file))
+        calculate(read_file(args.file), uss=args.uss)
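In the fast-bleu branch, calculate builds fast_bleu.SelfBLEU from tokenized sentences and averages the 4-gram scores that get_score() returns. A stand-alone sketch of that branch, with toy sentences and the 4-gram weights written out explicitly (the script itself relies on the library default producing the same key 4):

# Toy self-BLEU computation mirroring the fast-bleu branch of calculate().
from fast_bleu import SelfBLEU

sentences = [
    ["i", "am", "looking", "for", "a", "cheap", "hotel"],
    ["i", "want", "a", "cheap", "hotel", "in", "the", "centre"],
    ["what", "area", "would", "you", "like"],
]
bleu = SelfBLEU(sentences, weights={4: (0.25, 0.25, 0.25, 0.25)})
scores = bleu.get_score()              # dict: weight key -> per-sentence scores
print(sum(scores[4]) / len(scores[4])) # corpus-level self-BLEU, as printed above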
@@ -2,6 +2,7 @@ import json
 import os
 from argparse import ArgumentParser
 from datetime import datetime
+import numpy as np
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -12,6 +13,8 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
 from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
 from convlab.policy.ussT5.evaluate import tri_convert
+from datasets import load_metric
 def arg_parser():
     parser = ArgumentParser()
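The newly imported load_metric is used further down (in utterance()) to score generated utterances with sacreBLEU, which expects one string per prediction and a list of reference strings per prediction. A minimal check with invented strings, assuming the datasets and sacrebleu packages are installed:

# Minimal sacreBLEU call mirroring the shapes utterance() passes below.
from datasets import load_metric

bleu_metric = load_metric("sacrebleu")
preds = ["the hotel is in the centre ."]
refs = [["the hotel is located in the centre ."]]  # one reference list per prediction
result = bleu_metric.compute(predictions=preds, references=refs)
print(result["score"])                             # BLEU on a 0-100 scale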
@@ -104,29 +107,30 @@ def generate_result(model_checkpoint, data, stop=-1):
 def read_result(result):
-    preds = []
-    label = []
+    d = {}
+    for d_name in ["satisfaction score", "utterance generation", "action prediction"]:
+        d[d_name] = {"preds": [], "label": []}
     for r in result:
-        if "satisfaction score" in r["input_text"]:
-            preds.append(r["preds"])
-            label.append(r["label"])
-    return preds, label
+        for d_name in ["satisfaction score", "utterance generation", "action prediction"]:
+            if d_name in r["input_text"]:
+                d[d_name]["preds"].append(r["preds"])
+                d[d_name]["label"].append(r["label"])
+    return d
-def main():
-    args = arg_parser()
-    if args.gen_file:
-        preds, label = read_result(json.load(open(args.gen_file)))
-    else:
-        data = build_data(load_experiment_dataset(args.data)["test"])
-        results = generate_result(args.model, data, args.stop)
-        preds, label = read_result(results)
+def satisfaction(model, d):
     # satisfaction
     all_sentiment = ["Neutral", "Negative", "Positive"]
     print(all_sentiment)
-    tri_f1 = metrics.f1_score(label, preds, average="macro")
-    sep_f1 = metrics.f1_score(label, preds, average=None, labels=all_sentiment)
+    tri_f1 = metrics.f1_score(
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], average="macro")
+    sep_f1 = metrics.f1_score(
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], average=None, labels=all_sentiment)
     cm = metrics.confusion_matrix(
-        label, preds, normalize="true", labels=all_sentiment)
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], normalize="true", labels=all_sentiment)
     disp = metrics.ConfusionMatrixDisplay(
         confusion_matrix=cm,
         display_labels=all_sentiment)
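read_result now buckets predictions and labels per task, so satisfaction() reads the sentiment pairs straight from d["satisfaction score"]. A small sketch of that structure and the sklearn calls on it, with invented labels and assuming metrics comes from scikit-learn as elsewhere in the script:

# Hypothetical read_result() output (values invented) and the metrics computed on it.
from sklearn import metrics

d = {
    "satisfaction score": {
        "label": ["Neutral", "Positive", "Negative", "Neutral"],
        "preds": ["Neutral", "Positive", "Negative", "Positive"],
    },
    "utterance generation": {"preds": [], "label": []},
    "action prediction": {"preds": [], "label": []},
}
all_sentiment = ["Neutral", "Negative", "Positive"]
tri_f1 = metrics.f1_score(d["satisfaction score"]["label"],
                          d["satisfaction score"]["preds"], average="macro")
cm = metrics.confusion_matrix(d["satisfaction score"]["label"],
                              d["satisfaction score"]["preds"],
                              normalize="true", labels=all_sentiment)
print(tri_f1)   # macro-averaged F1 over the three sentiment classes
print(cm)       # row-normalized confusion matrix in the order of all_sentiment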
@@ -136,7 +140,66 @@ def main():
         "cm": [list(c) for c in list(cm)]}
     print(r)
     time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
-    plt.savefig(os.path.join(args.model, f"{time}-emowoz.png"))
+    plt.savefig(os.path.join(model, f"{time}-emowoz.png"))
+def utterance(model, d):
+    bleu_metric = load_metric("sacrebleu")
+    labels = [[utt] for utt in d["utterance generation"]["label"]]
+    bleu_score = bleu_metric.compute(
+        predictions=d["utterance generation"]["preds"],
+        references=labels,
+        force=True)
+    print(f"{model} bleu_score", bleu_score)
+def action(model, d):
+    score = {}
+    for preds, label in zip(d["action prediction"]["preds"], d["action prediction"]["label"]):
+        s = f1_score(preds, label)
+        for n, v in s.items():
+            if n not in score:
+                score[n] = []
+            score[n].append(v)
+    print(f"{model} action")
+    for n, v in score.items():
+        print(n, np.mean(v))
+def f1_score(prediction, label):
+    score = {}
+    tp = 0
+    pre = prediction.split(',')
+    lab = label.split(',')
+    for p in pre:
+        if p in lab:
+            tp += 1
+    score["precision"] = tp/len(pre)
+    score["recall"] = tp/len(lab)
+    score["F1"] = 0
+    if score["precision"]+score["recall"] > 0:
+        score["F1"] = 2*score["precision"]*score["recall"] / \
+            (score["precision"]+score["recall"])
+    if pre == lab:
+        score["acc"] = 1
+    else:
+        score["acc"] = 0
+    return score
+def main():
+    args = arg_parser()
+    if args.gen_file:
+        d = read_result(json.load(open(args.gen_file)))
+    else:
+        data = build_data(load_experiment_dataset(args.data)["test"])
+        results = generate_result(args.model, data, args.stop)
+        d = read_result(results)
+    model = args.model
+    satisfaction(model, d)
+    utterance(model, d)
+    action(model, d)
 if __name__ == "__main__":
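For the action task, f1_score() above scores one turn by splitting both strings on commas, counting overlapping items, and reporting precision, recall, F1, and exact-match accuracy; action() then averages each of these over all turns with np.mean. A worked toy example (action strings invented):

# Worked example of the per-turn action scoring defined in f1_score().
prediction = "inform-hotel-area,inform-hotel-price"
label = "inform-hotel-area,request-hotel-name"

pre = prediction.split(',')            # ['inform-hotel-area', 'inform-hotel-price']
lab = label.split(',')                 # ['inform-hotel-area', 'request-hotel-name']
tp = sum(1 for p in pre if p in lab)   # 1 action matches

precision = tp / len(pre)              # 0.5
recall = tp / len(lab)                 # 0.5
f1 = 2 * precision * recall / (precision + recall)  # 0.5
acc = int(pre == lab)                  # 0, not an exact match
print(precision, recall, f1, acc)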