diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py index c5cd5fbe771d715fbaa9ee28717f74f76b6dbbad..579315f76f36ffa9c77fda8e0a2b05bc30c75ec8 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoTUS/evaluate.py @@ -224,6 +224,7 @@ class Evaluator: print(f"{metric}: {result[metric]}") result["dialog"] = dialog_result + basename = "semantic_evaluation_result" json.dump(result, open(os.path.join( self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) @@ -241,7 +242,7 @@ def emotion_score(golden_emotions, gen_emotions): confusion_matrix=cm, display_labels=labels) disp.plot() plt.savefig("emotion.png") - r = {"macro_f1": macro_f1, "sep_f1": list( + r = {"macro_f1": float(macro_f1), "sep_f1": list( sep_f1), "cm": [list(c) for c in list(cm)]} print(r) return r