diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py index 0353c0d2693537ba61bacbb98bae5f45db6b2975..b77c4e115b953a12e24d55891c13f27974cb2c00 100644 --- a/convlab/policy/emoTUS/self_bleu.py +++ b/convlab/policy/emoTUS/self_bleu.py @@ -8,6 +8,7 @@ from tqdm import tqdm def arg_parser(): parser = argparse.ArgumentParser() parser.add_argument("--file", type=str) + parser.add_argument("--fast-bleu", action="store_true") return parser.parse_args() @@ -34,13 +35,21 @@ def SelfBLEU(sentences): return sum(result)/len(result) -def calculate(candidates): +def calculate(candidates, bleu_mode="torch"): sentences = get_sent(candidates) - bleu = SelfBLEU(sentences) + if bleu_mode == "torch": + x = SelfBLEU(sentences) + else: + bleu = fast_bleu.SelfBLEU(sentences) + x = bleu.get_score() # x = bleu.get_score() - print(bleu) + print(x) if __name__ == "__main__": args = arg_parser() - calculate(read_file(args.file)) + if args.fast_bleu: + import fast_bleu + calculate(read_file(args.file), "fast-bleu") + else: + calculate(read_file(args.file))