diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index 715553fa6be7626daae12d10fa3887f878c7bbda..a8a59da8e2df4309c0c7b4a9bc02cb7d97959d42 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -235,7 +235,7 @@ def emotion_score(golden_emotions, gen_emotions): macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro") sep_f1 = metrics.f1_score( golden_emotions, gen_emotions, average=None, labels=labels) - cm = metrics.confusion_matrix(golden_emotions, gen_emotions, labels) + cm = metrics.confusion_matrix(golden_emotions, gen_emotions, labels=labels) disp = metrics.ConfusionMatrixDisplay( confusion_matrix=cm, display_labels=labels) disp.plot()