From 97ff16b97575a1ad6722375877981fa2750ca397 Mon Sep 17 00:00:00 2001 From: linh <linh@hpc-login7.hilbert.hpc.uni-duesseldorf.de> Date: Wed, 4 Jan 2023 17:19:00 +0100 Subject: [PATCH] change file name --- convlab/policy/emoTUS/evaluate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index b7fd63a4..e19eeecf 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -218,7 +218,7 @@ class Evaluator: for metric in scores: result[metric] = sum(scores[metric])/len(scores[metric]) print(f"{metric}: {result[metric]}") - emo_score = emotion_score(golden_emotions, gen_emotions) + emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint) # for metric in emo_score: # result[metric] = emo_score[metric] # print(f"{metric}: {result[metric]}") @@ -230,7 +230,7 @@ class Evaluator: self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) -def emotion_score(golden_emotions, gen_emotions): +def emotion_score(golden_emotions, gen_emotions, dirname="."): labels = ["Neutral", "Fearful", "Dissatisfied", "Apologetic", "Abusive", "Excited", "Satisfied"] print(labels) @@ -241,7 +241,7 @@ def emotion_score(golden_emotions, gen_emotions): disp = metrics.ConfusionMatrixDisplay( confusion_matrix=cm, display_labels=labels) disp.plot() - plt.savefig("emotion.png") + plt.savefig(os.path.join(dirname, "emotion.png")) r = {"macro_f1": float(macro_f1), "sep_f1": list( sep_f1), "cm": [list(c) for c in list(cm)]} print(r) -- GitLab