diff --git a/convlab2/nlg/evaluate_unified_datasets.py b/convlab2/nlg/evaluate_unified_datasets.py
index 201181f6d03fc1dea976823bb0465b0ddf065cbb..544c3c37cff10e444869e8579d7210fa5d256032 100644
--- a/convlab2/nlg/evaluate_unified_datasets.py
+++ b/convlab2/nlg/evaluate_unified_datasets.py
@@ -1,5 +1,6 @@
 import sys
 from nltk.translate.bleu_score import corpus_bleu
+import sacrebleu
 from nltk.tokenize import word_tokenize
 sys.path.append('../..')
 import json
@@ -32,9 +33,10 @@ def evaluate(predict_result, ontology):
     references = []
     candidates = []
     for i in range(len(predict_result)):
-        references.append([word_tokenize(predict_result[i]['utterance'])])
-        candidates.append(word_tokenize(predict_result[i]['prediction']))
-    metrics['bleu'] = corpus_bleu(references, candidates)
+        references.append(predict_result[i]['utterance'])
+        candidates.append(predict_result[i]['prediction'])
+    # metrics['bleu'] = corpus_bleu(references, candidates)
+    metrics['bleu'] = sacrebleu.corpus_bleu(candidates, [references], lowercase=True).score
 
     # ERROR Rate
     ## get all values in ontology
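
Not part of the patch: a minimal sketch of the two BLEU paths this change swaps, assuming nltk (with the punkt tokenizer data) and sacrebleu are installed. The example strings are hypothetical; the point is the differing input shapes and score scales.

# Illustrative only; not part of evaluate_unified_datasets.py.
from nltk.tokenize import word_tokenize          # requires nltk punkt data
from nltk.translate.bleu_score import corpus_bleu
import sacrebleu

utterances = ["I'd like a cheap hotel in the north."]      # gold utterances (hypothetical)
predictions = ["i would like a cheap hotel in the north"]  # model outputs (hypothetical)

# Old path: NLTK expects pre-tokenized text, one list of reference
# token lists per candidate; the score is on a 0-1 scale.
nltk_refs = [[word_tokenize(u)] for u in utterances]
nltk_cands = [word_tokenize(p) for p in predictions]
nltk_bleu = corpus_bleu(nltk_refs, nltk_cands)

# New path: sacreBLEU takes raw strings (candidates first, then a list of
# reference streams), tokenizes internally, and reports BLEU on a 0-100 scale.
sacre_bleu = sacrebleu.corpus_bleu(predictions, [utterances], lowercase=True).score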