From 03f5ae09c2152a81e05474580ba3da0fb1c11c58 Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Wed, 25 Jan 2023 15:51:30 +0100
Subject: [PATCH] wip

---
 convlab/policy/emoTUS/self_bleu.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py
index 0353c0d2..b77c4e11 100644
--- a/convlab/policy/emoTUS/self_bleu.py
+++ b/convlab/policy/emoTUS/self_bleu.py
@@ -8,6 +8,7 @@ from tqdm import tqdm
 def arg_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--file", type=str)
+    parser.add_argument("--fast-bleu", action="store_true")
     return parser.parse_args()
 
 
@@ -34,13 +35,21 @@ def SelfBLEU(sentences):
     return sum(result)/len(result)
 
 
-def calculate(candidates):
+def calculate(candidates, bleu_mode="torch"):
     sentences = get_sent(candidates)
-    bleu = SelfBLEU(sentences)
+    if bleu_mode == "torch":
+        x = SelfBLEU(sentences)
+    else:
+        bleu = fast_bleu.SelfBLEU(sentences)
+        x = bleu.get_score()
     # x = bleu.get_score()
-    print(bleu)
+    print(x)
 
 
 if __name__ == "__main__":
     args = arg_parser()
-    calculate(read_file(args.file))
+    if args.fast_bleu:
+        import fast_bleu
+        calculate(read_file(args.file), "fast-bleu")
+    else:
+        calculate(read_file(args.file))
-- 
GitLab