From 47aca1d9cc89a0c1b19331775460a37981abb0bf Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Wed, 25 Jan 2023 15:59:56 +0100
Subject: [PATCH] wip

---
 convlab/policy/emoTUS/self_bleu.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py
index b77c4e11..47bcd9df 100644
--- a/convlab/policy/emoTUS/self_bleu.py
+++ b/convlab/policy/emoTUS/self_bleu.py
@@ -17,11 +17,17 @@ def read_file(file_name):
     return nlg_candidates
 
 
-def get_sent(candidates):
-    if "log" in candidates:
-        return [x["gen_utts"] for x in candidates["log"]]
+def get_sent(candidates, bleu_mode="torch"):
+    if bleu_mode == "torch":
+        if "log" in candidates:
+            return [x["gen_utts"] for x in candidates["log"]]
+        else:
+            return [x["gen_utts"] for x in candidates["dialog"]]
     else:
-        return [x["gen_utts"] for x in candidates["dialog"]]
+        if "log" in candidates:
+            return [x["gen_utts"].split() for x in candidates["log"]]
+        else:
+            return [x["gen_utts"].split() for x in candidates["dialog"]]
 
 
 def SelfBLEU(sentences):
@@ -36,7 +42,7 @@ def SelfBLEU(sentences):
 
 
 def calculate(candidates, bleu_mode="torch"):
-    sentences = get_sent(candidates)
+    sentences = get_sent(candidates, bleu_mode)
     if bleu_mode == "torch":
         x = SelfBLEU(sentences)
     else:
-- 
GitLab