Commit 15b002b0 authored by Hsien-Chin Lin

wip

parent 2c5ddaec
@@ -9,6 +9,7 @@ def arg_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--file", type=str)
     parser.add_argument("--fast-bleu", action="store_true")
+    parser.add_argument("--uss", action="store_true")
     return parser.parse_args()
@@ -17,13 +18,17 @@ def read_file(file_name):
     return nlg_candidates
 
-def get_sent(candidates, bleu_mode="torch"):
+def get_sent(candidates, bleu_mode="torch", uss=False):
     if bleu_mode == "torch":
+        if uss:
+            return [x["preds"] for x in candidates]
         if "log" in candidates:
             return [x["gen_utts"] for x in candidates["log"]]
         else:
             return [x["gen_utts"] for x in candidates["dialog"]]
     else:
+        if uss:
+            return [x["preds"].split() for x in candidates]
         if "log" in candidates:
             return [x["gen_utts"].split() for x in candidates["log"]]
         else:
@@ -41,20 +46,22 @@ def SelfBLEU(sentences):
     return sum(result)/len(result)
 
-def calculate(candidates, bleu_mode="torch"):
-    sentences = get_sent(candidates, bleu_mode)
+def calculate(candidates, bleu_mode="torch", uss=False):
+    sentences = get_sent(candidates, bleu_mode, uss)
     if bleu_mode == "torch":
         x = SelfBLEU(sentences)
     else:
         bleu = fast_bleu.SelfBLEU(sentences)
         x = bleu.get_score()
         # x = bleu.get_score()
+    # print(x)
     print(sum(x[4])/len(x[4]))
 
 
 if __name__ == "__main__":
     args = arg_parser()
     if args.fast_bleu:
         import fast_bleu
-        calculate(read_file(args.file), "fast-bleu")
+        calculate(read_file(args.file), "fast-bleu", uss=args.uss)
     else:
-        calculate(read_file(args.file))
+        calculate(read_file(args.file), uss=args.uss)
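For context on the change above: the new --uss flag makes the script score USS-style generation results, i.e. a flat list of records carrying a "preds" field, instead of dialogue logs whose generated utterances sit under "log" or "dialog" as "gen_utts". A minimal sketch of the two input shapes get_sent() now accepts, plus a hypothetical invocation (script and file names are placeholders):

# Hypothetical invocations (script and result-file names are illustrative):
#   python self_bleu.py --file uss_results.json --uss
#   python self_bleu.py --file dialogue_log.json --fast-bleu
#
# Input shapes inferred from get_sent():
uss_results = [                    # with --uss: a flat list of prediction records
    {"preds": "the hotel is in the centre of town"},
    {"preds": "what price range are you looking for"},
]
dialogue_log = {                   # default: generated utterances under "log" (or "dialog")
    "log": [
        {"gen_utts": "the hotel is in the centre of town"},
        {"gen_utts": "what price range are you looking for"},
    ]
}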
@@ -2,6 +2,7 @@ import json
 import os
 from argparse import ArgumentParser
 from datetime import datetime
+import numpy as np
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -12,6 +13,8 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
 from convlab.policy.tus.unify.util import create_goal, load_experiment_dataset
 from convlab.policy.ussT5.evaluate import tri_convert
+from datasets import load_metric
 
 
 def arg_parser():
     parser = ArgumentParser()
@@ -104,29 +107,30 @@ def generate_result(model_checkpoint, data, stop=-1):
 def read_result(result):
-    preds = []
-    label = []
+    d = {}
+    for d_name in ["satisfaction score", "utterance generation", "action prediction"]:
+        d[d_name] = {"preds": [], "label": []}
     for r in result:
-        if "satisfaction score" in r["input_text"]:
-            preds.append(r["preds"])
-            label.append(r["label"])
-    return preds, label
+        for d_name in ["satisfaction score", "utterance generation", "action prediction"]:
+            if d_name in r["input_text"]:
+                d[d_name]["preds"].append(r["preds"])
+                d[d_name]["label"].append(r["label"])
+    return d
 
 
-def main():
-    args = arg_parser()
-    if args.gen_file:
-        preds, label = read_result(json.load(open(args.gen_file)))
-    else:
-        data = build_data(load_experiment_dataset(args.data)["test"])
-        results = generate_result(args.model, data, args.stop)
-        preds, label = read_result(results)
+def satisfaction(model, d):
+    # satisfaction
     all_sentiment = ["Neutral", "Negative", "Positive"]
     print(all_sentiment)
-    tri_f1 = metrics.f1_score(label, preds, average="macro")
-    sep_f1 = metrics.f1_score(label, preds, average=None, labels=all_sentiment)
+    tri_f1 = metrics.f1_score(
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], average="macro")
+    sep_f1 = metrics.f1_score(
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], average=None, labels=all_sentiment)
     cm = metrics.confusion_matrix(
-        label, preds, normalize="true", labels=all_sentiment)
+        d["satisfaction score"]["label"],
+        d["satisfaction score"]["preds"], normalize="true", labels=all_sentiment)
     disp = metrics.ConfusionMatrixDisplay(
         confusion_matrix=cm,
         display_labels=all_sentiment)
@@ -136,7 +140,66 @@ def main():
         "cm": [list(c) for c in list(cm)]}
     print(r)
     time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
-    plt.savefig(os.path.join(args.model, f"{time}-emowoz.png"))
+    plt.savefig(os.path.join(model, f"{time}-emowoz.png"))
+
+
+def utterance(model, d):
+    bleu_metric = load_metric("sacrebleu")
+    labels = [[utt] for utt in d["utterance generation"]["label"]]
+    bleu_score = bleu_metric.compute(
+        predictions=d["utterance generation"]["preds"],
+        references=labels,
+        force=True)
+    print(f"{model} bleu_score", bleu_score)
+
+
+def action(model, d):
+    score = {}
+    for preds, label in zip(d["action prediction"]["preds"], d["action prediction"]["label"]):
+        s = f1_score(preds, label)
+        for n, v in s.items():
+            if n not in score:
+                score[n] = []
+            score[n].append(v)
+    print(f"{model} action")
+    for n, v in score.items():
+        print(n, np.mean(v))
+
+
+def f1_score(prediction, label):
+    score = {}
+    tp = 0
+    pre = prediction.split(',')
+    lab = label.split(',')
+    for p in pre:
+        if p in lab:
+            tp += 1
+    score["precision"] = tp/len(pre)
+    score["recall"] = tp/len(lab)
+    score["F1"] = 0
+    if score["precision"]+score["recall"] > 0:
+        score["F1"] = 2*score["precision"]*score["recall"] / \
+            (score["precision"]+score["recall"])
+    if pre == lab:
+        score["acc"] = 1
+    else:
+        score["acc"] = 0
+    return score
+
+
+def main():
+    args = arg_parser()
+    if args.gen_file:
+        d = read_result(json.load(open(args.gen_file)))
+    else:
+        data = build_data(load_experiment_dataset(args.data)["test"])
+        results = generate_result(args.model, data, args.stop)
+        d = read_result(results)
+    model = args.model
+    satisfaction(model, d)
+    utterance(model, d)
+    action(model, d)
 
 
 if __name__ == "__main__":
...
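For context on the second file: read_result now groups the generated records into three tasks ("satisfaction score", "utterance generation", "action prediction") by checking which task name appears in each record's "input_text", and the new f1_score helper scores comma-separated action strings turn by turn, which action() then averages with np.mean. A small worked sketch of that arithmetic, with illustrative action strings:

# Worked example of the added f1_score() on comma-separated action strings
# (the action strings are illustrative, not taken from the dataset).
prediction = "inform,request,bye"
label = "inform,bye"

pre, lab = prediction.split(','), label.split(',')
tp = sum(1 for p in pre if p in lab)                # 2 ("inform" and "bye" match)
precision = tp / len(pre)                           # 2/3
recall = tp / len(lab)                              # 2/2 = 1.0
f1 = 2 * precision * recall / (precision + recall)  # 0.8
acc = 1 if pre == lab else 0                        # 0 (exact match only)
print(precision, recall, f1, acc)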