diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index d67a1e1a93fc9ccae28453351ee637331a1315ea..9b479ca2f627a085a5fa7d618098bf29f7045da2 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -232,6 +232,7 @@ class Evaluator: def emotion_score(golden_emotions, gen_emotions): labels = ["Neutral", "Disappointed", "Dissatisfied", "Apologetic", "Abusive", "Excited", "Satisfied"] + print(labels) macro_f1 = metrics.f1_score(golden_emotions, gen_emotions, average="macro") sep_f1 = metrics.f1_score( golden_emotions, gen_emotions, average=None, labels=labels) @@ -240,7 +241,7 @@ def emotion_score(golden_emotions, gen_emotions): confusion_matrix=cm, display_labels=labels) disp.plot() plt.savefig("emotion.png") - r = {"macro_f1": macro_f1, "sep_f1": sep_f1, "cm": cm} + r = {"macro_f1": macro_f1, "sep_f1": list(sep_f1), "cm": list(cm)} print(r) return r