diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoTUS/self_bleu.py index 4a245cc059e19bd717f532fadae121899dc322ff..0353c0d2693537ba61bacbb98bae5f45db6b2975 100644 --- a/convlab/policy/emoTUS/self_bleu.py +++ b/convlab/policy/emoTUS/self_bleu.py @@ -22,22 +22,23 @@ def get_sent(candidates): else: return [x["gen_utts"] for x in candidates["dialog"]] + def SelfBLEU(sentences): metric = load_metric("sacrebleu") result = [] - for i, sent in tqdm(enumerate(sentences),ascii=True): - r = metric.compute(predictions=[sent], references=[sentences[i:]+sentences[i+1:]]) + for i, sent in tqdm(enumerate(sentences), ascii=True): + r = metric.compute(predictions=[sent], references=[ + sentences[i:]+sentences[i+1:]]) result.append(r["score"]) - return sum(result)/len(result) def calculate(candidates): sentences = get_sent(candidates) bleu = SelfBLEU(sentences) - x = bleu.get_score() - print(x) + # x = bleu.get_score() + print(bleu) if __name__ == "__main__":