From cf02ef68fde70a82e41d5223a86eb15b4619e7e2 Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Wed, 25 Jan 2023 11:28:21 +0100 Subject: [PATCH] add self_bleu --- convlab/policy/emoTUS/self_bleu.py | 45 ++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 convlab/policy/emoTUS/self_bleu.py diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py new file mode 100644 index 00000000..4a245cc0 --- /dev/null +++ b/convlab/policy/emoTUS/self_bleu.py @@ -0,0 +1,45 @@ +# from fast_bleu import SelfBLEU +import argparse +import json +from datasets import Dataset, load_metric +from tqdm import tqdm + + +def arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--file", type=str) + return parser.parse_args() + + +def read_file(file_name): + nlg_candidates = json.load(open(file_name)) + return nlg_candidates + + +def get_sent(candidates): + if "log" in candidates: + return [x["gen_utts"] for x in candidates["log"]] + else: + return [x["gen_utts"] for x in candidates["dialog"]] + +def SelfBLEU(sentences): + metric = load_metric("sacrebleu") + result = [] + for i, sent in tqdm(enumerate(sentences),ascii=True): + r = metric.compute(predictions=[sent], references=[sentences[i:]+sentences[i+1:]]) + result.append(r["score"]) + + + return sum(result)/len(result) + + +def calculate(candidates): + sentences = get_sent(candidates) + bleu = SelfBLEU(sentences) + x = bleu.get_score() + print(x) + + +if __name__ == "__main__": + args = arg_parser() + calculate(read_file(args.file)) -- GitLab