From 47aca1d9cc89a0c1b19331775460a37981abb0bf Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Wed, 25 Jan 2023 15:59:56 +0100 Subject: [PATCH] wip --- convlab/policy/emoTUS/self_bleu.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py index b77c4e11..47bcd9df 100644 --- a/convlab/policy/emoTUS/self_bleu.py +++ b/convlab/policy/emoTUS/self_bleu.py @@ -17,11 +17,17 @@ def read_file(file_name): return nlg_candidates -def get_sent(candidates): - if "log" in candidates: - return [x["gen_utts"] for x in candidates["log"]] +def get_sent(candidates, bleu_mode="torch"): + if bleu_mode == "torch": + if "log" in candidates: + return [x["gen_utts"] for x in candidates["log"]] + else: + return [x["gen_utts"] for x in candidates["dialog"]] else: - return [x["gen_utts"] for x in candidates["dialog"]] + if "log" in candidates: + return [x["gen_utts"].split() for x in candidates["log"]] + else: + return [x["gen_utts"].split() for x in candidates["dialog"]] def SelfBLEU(sentences): @@ -36,7 +42,7 @@ def SelfBLEU(sentences): def calculate(candidates, bleu_mode="torch"): - sentences = get_sent(candidates) + sentences = get_sent(candidates, bleu_mode) if bleu_mode == "torch": x = SelfBLEU(sentences) else: -- GitLab