diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index b7fd63a4ea4fa7dd7f03e5aa1f0e262a8699cca7..e19eeecfec21d5a490c1296a7b4e207ec9b6ea4b 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -218,7 +218,7 @@ class Evaluator: for metric in scores: result[metric] = sum(scores[metric])/len(scores[metric]) print(f"{metric}: {result[metric]}") - emo_score = emotion_score(golden_emotions, gen_emotions) + emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint) # for metric in emo_score: # result[metric] = emo_score[metric] # print(f"{metric}: {result[metric]}") @@ -230,7 +230,7 @@ class Evaluator: self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) -def emotion_score(golden_emotions, gen_emotions): +def emotion_score(golden_emotions, gen_emotions, dirname="."): labels = ["Neutral", "Fearful", "Dissatisfied", "Apologetic", "Abusive", "Excited", "Satisfied"] print(labels) @@ -241,7 +241,7 @@ def emotion_score(golden_emotions, gen_emotions): disp = metrics.ConfusionMatrixDisplay( confusion_matrix=cm, display_labels=labels) disp.plot() - plt.savefig("emotion.png") + plt.savefig(os.path.join(dirname, "emotion.png")) r = {"macro_f1": float(macro_f1), "sep_f1": list( sep_f1), "cm": [list(c) for c in list(cm)]} print(r)