Commit dbef82de authored by Hsien-Chin Lin

change evaluation script

parent 6cfe60d6
@@ -5,7 +5,9 @@ from argparse import ArgumentParser
 from pprint import pprint
 import torch
-from convlab.nlg.evaluate import fine_SER, get_bleu4
+from convlab.nlg.evaluate import fine_SER
+from datasets import load_metric
 # from convlab.policy.genTUS.pg.stepGenTUSagent import \
 #     stepGenTUSPG as UserPolicy
 from convlab.policy.genTUS.stepGenTUS import UserActionPolicy
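
Note: the BLEU computation moves from convlab's get_bleu4 helper to the sacrebleu metric from Hugging Face datasets. For reference, a minimal sketch of that metric API (example strings invented; datasets.load_metric was later deprecated in favour of the evaluate package):

    from datasets import load_metric

    bleu_metric = load_metric("sacrebleu")
    # predictions: list of generated strings; references: one list of
    # reference strings per prediction.
    score = bleu_metric.compute(
        predictions=["i want a cheap restaurant"],
        references=[["i am looking for a cheap restaurant"]])
    print(score["score"])  # corpus-level BLEU on a 0-100 scale
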
@@ -52,7 +54,13 @@ class Evaluator:
     def generate_results(self, f_eval, golden=False):
         in_file = json.load(open(f_eval))
-        dialog_acts, golden_utts, gen_utts = [], [], []
+        r = {
+            "input": [],
+            "golden_acts": [],
+            "golden_utts": [],
+            "gen_acts": [],
+            "gen_utts": []
+        }
         for dialog in tqdm(in_file['dialog']):
             inputs = dialog["in"]
             labels = self.usr._parse_output(dialog["out"])
@@ -66,113 +74,140 @@ class Evaluator:
                 self.usr._generate_action(inputs))
             usr_act = self.usr._remove_illegal_action(output["action"])
             usr_utt = output["text"]
-            dialog_acts.append(usr_act)
-            golden_utts.append(labels["text"])
-            gen_utts.append(usr_utt)
-        return dialog_acts, golden_utts, gen_utts
+            r["input"].append(inputs)
+            r["golden_acts"].append(labels["action"])
+            r["golden_utts"].append(labels["text"])
+            r["gen_acts"].append(usr_act)
+            r["gen_utts"].append(usr_utt)
+        return r
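
generate_results now returns a single dict of parallel lists instead of three loose lists. An illustrative instance of the structure with one turn (all values invented; the act format is an assumption based on convlab's [intent, domain, slot, value] convention):

    r = {
        "input": ["<user goal + dialogue history>"],
        "golden_acts": [[["inform", "restaurant", "area", "centre"]]],
        "golden_utts": ["I am looking for a restaurant in the centre."],
        "gen_acts": [[["inform", "restaurant", "area", "centre"]]],
        "gen_utts": ["i want a restaurant in the centre"]
    }
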
-    def self_ser(self, f_eval):
+    def read_generated_result(self, f_eval):
         in_file = json.load(open(f_eval))
-        dialog_acts, golden_utts, gen_utts = [], [], []
+        r = {
+            "input": [],
+            "golden_acts": [],
+            "golden_utts": [],
+            "gen_acts": [],
+            "gen_utts": []
+        }
         for dialog in tqdm(in_file['dialog']):
-            dialog_acts.append(dialog["predict_action"])
-            golden_utts.append(dialog["answer_text"])
-            gen_utts.append(dialog["predict_text"])
-        return dialog_acts, golden_utts, gen_utts
+            for x in dialog:
+                r[x].append(dialog[x])
+        return r
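
read_generated_result simply transposes the saved per-turn dicts back into parallel lists, so it assumes the file was written with exactly the keys initialised in r. A minimal standalone illustration of that transpose (data invented):

    dialogs = [{"input": "i1", "gen_acts": "a1"},
               {"input": "i2", "gen_acts": "a2"}]
    r = {"input": [], "gen_acts": []}
    for dialog in dialogs:
        for x in dialog:          # iterate the keys of each saved turn
            r[x].append(dialog[x])
    # r == {"input": ["i1", "i2"], "gen_acts": ["a1", "a2"]}
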
     def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
         if input_file:
             print("Force generation")
-            dialog_acts, golden_utts, gen_utts = self.generate_results(
-                input_file, golden)
-            r = {'dialog': []}
-            for act, ans, pre in zip(dialog_acts, golden_utts, gen_utts):
-                r['dialog'].append({"predict_action": act,
-                                    "answer_text": ans,
-                                    "predict_text": pre})
-            if generated_file:
-                print(f"update result in {generated_file}")
-            else:
-                generated_file = os.path.join(
-                    self.model_checkpoint, 'generation_results.json')
-                print(f"dump result to {generated_file}")
-            json.dump(r, open(generated_file, 'w'), indent=2)
+            gen_r = self.generate_results(input_file, golden)
         elif generated_file:
-            dialog_acts, golden_utts, gen_utts = self.self_ser(generated_file)
+            gen_r = self.read_generated_result(generated_file)
         else:
             print("You must specify the input_file or the generated_file")
         nlg_eval = {
-            "dialog_acts": dialog_acts,
-            "golden_utts": golden_utts,
-            "gen_utts": gen_utts
+            "golden": golden,
+            "metrics": {},
+            "dialog": []
         }
+        for input, golden_act, golden_utt, gen_act, gen_utt in zip(
+                gen_r["input"], gen_r["golden_acts"], gen_r["golden_utts"],
+                gen_r["gen_acts"], gen_r["gen_utts"]):
+            nlg_eval["dialog"].append({
+                "input": input,
+                "golden_acts": golden_act,
+                "golden_utts": golden_utt,
+                "gen_acts": gen_act,
+                "gen_utts": gen_utt
+            })
-        print("Calculate SER for golden responses")
-        missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
-            nlg_eval["dialog_acts"], nlg_eval["golden_utts"])
-        print("Golden response Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
-            missing, total, hallucinate, missing/total))
-        print("Calculate SER")
-        missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
-            nlg_eval["dialog_acts"], nlg_eval["gen_utts"])
-        print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
-            "genTUSNLG", missing, total, hallucinate, missing/total))
-        bleu4 = get_bleu4(nlg_eval["dialog_acts"],
-                          nlg_eval["golden_utts"], nlg_eval["gen_utts"])
-        print("BLEU-4: %.4f" % bleu4)
+        if golden:
+            print("Calculate BLEU")
+            bleu_metric = load_metric("sacrebleu")
+            labels = [[utt] for utt in gen_r["golden_utts"]]
+            bleu_score = bleu_metric.compute(predictions=gen_r["gen_utts"],
+                                             references=labels,
+                                             force=True)
+            print("bleu_metric", bleu_score)
+            nlg_eval["metrics"]["bleu"] = bleu_score
+        else:
+            print("Calculate SER")
+            missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
+                gen_r["gen_acts"], gen_r["gen_utts"])
+            print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(
+                "genTUSNLG", missing, total, hallucinate, missing/total))
+            nlg_eval["metrics"]["SER"] = missing/total
         dir_name = self.model_checkpoint
         json.dump(nlg_eval,
-                  open(os.path.join(dir_name, "nlg_eval.json"), 'w'))
+                  open(os.path.join(dir_name, "nlg_eval.json"), 'w'),
+                  indent=2)
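
After this change nlg_eval.json bundles the evaluation mode, the corpus-level metrics, and the per-turn records in one file. An illustrative (invented) instance, written as a Python literal:

    nlg_eval = {
        "golden": False,
        "metrics": {"SER": 0.07},   # or {"bleu": {...}} when golden=True
        "dialog": [
            {"input": "...", "golden_acts": [], "golden_utts": "...",
             "gen_acts": [], "gen_utts": "..."}
        ]
    }
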
-    def evaluation(self, f_eval):
-        in_file = json.load(open(f_eval))
-        dialog_result = []
-        result = {}
-        scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
+    def evaluation(self, input_file=None, generated_file=None):
+        force_prediction = True
+        if generated_file:
+            gen_file = json.load(open(generated_file))
+            force_prediction = False
+            if gen_file["golden"]:
+                force_prediction = True
+        if force_prediction:
+            in_file = json.load(open(input_file))
+            dialog_result = []
+            gen_acts, golden_acts = [], []
+            # scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
             for dialog in tqdm(in_file['dialog']):
                 inputs = dialog["in"]
                 labels = self.usr._parse_output(dialog["out"])
                 ans_action = self.usr._remove_illegal_action(labels["action"])
                 preds = self.usr._generate_action(inputs)
                 preds = self.usr._parse_output(preds)
                 # print("inputs", inputs)
                 # print("goal_list", self.usr.kg.user_goal)
                 usr_action = self.usr._remove_illegal_action(preds["action"])
                 # print("usr", usr_action)
                 # print("ans", ans_action)
-                s = f1_measure(preds=usr_action, labels=ans_action)
-                for metric in scores:
-                    scores[metric].append(s[metric])
-                print("ans", ans_action)
-                print("pre", usr_action)
-                d = {"in": inputs,
-                     "answer_action": ans_action,
-                     "predict_action": usr_action}
+                gen_acts.append(usr_action)
+                golden_acts.append(ans_action)
+                d = {"input": inputs,
+                     "golden_acts": ans_action,
+                     "gen_acts": usr_action}
                 if "text" in preds:
-                    d["answer_text"] = labels["text"]
-                    d["predict_text"] = preds["text"]
+                    d["golden_utts"] = labels["text"]
+                    d["gen_utts"] = preds["text"]
                 # print("pred text", preds["text"])
                 dialog_result.append(d)
+        else:
+            gen_acts, golden_acts = [], []
+            # gen_file['dialog'] is a list of per-turn dicts, so collect the
+            # stored acts turn by turn rather than indexing the list by key.
+            for dialog in gen_file['dialog']:
+                gen_acts.append(dialog["gen_acts"])
+                golden_acts.append(dialog["golden_acts"])
+            dialog_result = gen_file['dialog']
+        scores = {"precision": [], "recall": [], "f1": [], "turn_acc": []}
+        for gen_act, golden_act in zip(gen_acts, golden_acts):
+            s = f1_measure(preds=gen_act, labels=golden_act)
+            for metric in scores:
+                scores[metric].append(s[metric])
+        result = {}
         for metric in scores:
             result[metric] = sum(scores[metric])/len(scores[metric])
             print(f"{metric}: {result[metric]}")
         result["dialog"] = dialog_result
-        basename = "evaluation_result"
-        if self.model_weight:
-            json.dump(result, open(os.path.join(
-                'results', f"{basename}.json"), 'w'))
-        else:
-            basename = "semantic_evaluation_result"
-            json.dump(result, open(os.path.join(
-                self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
+        basename = "semantic_evaluation_result"
+        json.dump(result, open(os.path.join(
+            self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
+        # if self.model_weight:
+        #     json.dump(result, open(os.path.join(
+        #         'results', f"{basename}.json"), 'w'))
+        # else:
+        #     json.dump(result, open(os.path.join(
+        #         self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
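
evaluation() now has two modes: with input_file it forces fresh predictions from the model, while with generated_file it re-scores previously dumped turns; a file produced with golden=True stores golden rather than generated acts, so prediction is forced again. A hedged usage sketch (constructor arguments and file names are illustrative, not part of the diff):

    evaluator = Evaluator(model_checkpoint, dataset, model_weight, only_action)
    # Re-score cached predictions without re-running the model:
    evaluator.evaluation(generated_file="nlg_eval.json")
    # Or generate and score from the raw test file:
    evaluator.evaluation(input_file="test.json")
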
 def f1_measure(preds, labels):
@@ -199,6 +234,7 @@ def main():
         args.dataset,
         args.model_weight,
         args.only_action)
+    print("model checkpoint", args.model_checkpoint)
     print("generated_file", args.generated_file)
     print("input_file", args.input_file)
     with torch.no_grad():