diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py index b77c4e115b953a12e24d55891c13f27974cb2c00..47bcd9df0f5022ac66eac8e75ac3939941127887 100644 --- a/convlab/policy/emoTUS/self_bleu.py +++ b/convlab/policy/emoTUS/self_bleu.py @@ -17,11 +17,17 @@ def read_file(file_name): return nlg_candidates -def get_sent(candidates): - if "log" in candidates: - return [x["gen_utts"] for x in candidates["log"]] +def get_sent(candidates, bleu_mode="torch"): + if bleu_mode == "torch": + if "log" in candidates: + return [x["gen_utts"] for x in candidates["log"]] + else: + return [x["gen_utts"] for x in candidates["dialog"]] else: - return [x["gen_utts"] for x in candidates["dialog"]] + if "log" in candidates: + return [x["gen_utts"].split() for x in candidates["log"]] + else: + return [x["gen_utts"].split() for x in candidates["dialog"]] def SelfBLEU(sentences): @@ -36,7 +42,7 @@ def SelfBLEU(sentences): def calculate(candidates, bleu_mode="torch"): - sentences = get_sent(candidates) + sentences = get_sent(candidates, bleu_mode) if bleu_mode == "torch": x = SelfBLEU(sentences) else: