diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py index 37de963ffd2addaf011cb826875a3a2b70889f89..4e7f038d29425f21299c2a3e9ab8158401cdc3a8 100644 --- a/convlab/policy/emoTUS/self_bleu.py +++ b/convlab/policy/emoTUS/self_bleu.py @@ -17,11 +17,17 @@ def read_file(file_name): return nlg_candidates -def get_sent(candidates): - if "log" in candidates: - return [x["gen_utts"] for x in candidates["log"]] +def get_sent(candidates, bleu_mode="torch"): + if bleu_mode == "torch": + if "log" in candidates: + return [x["gen_utts"] for x in candidates["log"]] + else: + return [x["gen_utts"] for x in candidates["dialog"]] else: - return [x["gen_utts"] for x in candidates["dialog"]] + if "log" in candidates: + return [x["gen_utts"].split() for x in candidates["log"]] + else: + return [x["gen_utts"].split() for x in candidates["dialog"]] def SelfBLEU(sentences): @@ -36,7 +42,7 @@ def SelfBLEU(sentences): def calculate(candidates, bleu_mode="torch"): - sentences = get_sent(candidates) + sentences = get_sent(candidates, bleu_mode) if bleu_mode == "torch": x = SelfBLEU(sentences) else: