From 77e39f4eec98671ccdfd044823b86030cbfc7508 Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Fri, 14 Apr 2023 11:37:49 +0200 Subject: [PATCH] fix typo --- convlab/policy/emoUS/evaluate.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/convlab/policy/emoUS/evaluate.py b/convlab/policy/emoUS/evaluate.py index bd97ab2b..a516f874 100644 --- a/convlab/policy/emoUS/evaluate.py +++ b/convlab/policy/emoUS/evaluate.py @@ -9,6 +9,7 @@ import torch from datasets import load_metric from sklearn import metrics from tqdm import tqdm +from pprint import pprint from convlab.nlg.evaluate import fine_SER from convlab.policy.emoUS.emoUS import UserActionPolicy @@ -194,13 +195,13 @@ class Evaluator: def dialog_result(self, dialog): x = {"gen_acts": [], "golden_acts": [], - "gen_emotions": [], "golden_emotions": []} + "gen_emotion": [], "golden_emotion": []} for d in dialog: x["gen_acts"].append(d["gen_acts"]) x["golden_acts"].append(d["golden_acts"]) - x["gen_emotions"].append(d["gen_emotion"]) - x["golden_emotions"].append(d["golden_emotion"]) + x["gen_emotion"].append(d["gen_emotion"]) + x["golden_emotion"].append(d["golden_emotion"]) return x def semantic_evaluation(self, x): @@ -246,8 +247,8 @@ class Evaluator: self.evaluation_result["semantic action prediction"][metric] = score if not golden_emotion and not golden_action: - r = emotion_score(x["golden_emotions"], - x["gen_emotions"], + r = emotion_score(x["golden_emotion"], + x["gen_emotion"], self.model_checkpoint) self.evaluation_result["emotion prediction"]["emotion"] = {} self.evaluation_result["emotion prediction"]["emotion"]["macro_f1"] = r["macro_f1"] @@ -260,9 +261,9 @@ class Evaluator: else: # transfer emotions to sentiment if the model do not generate sentiment golden_sentiment = [self.emo2sent[emo] - for emo in self.r["golden_emotions"]] + for emo in self.r["golden_emotion"]] gen_sentiment = [self.emo2sent[emo] - for emo in self.r["gen_emotions"]] + for emo in self.r["gen_emotion"]] r = sentiment_score( golden_sentiment, gen_sentiment, @@ -273,7 +274,7 @@ class Evaluator: self.evaluation_result["emotion prediction"]["sentiment"]["sep_f1"] = { emo: f1 for emo, f1 in zip(r["label"], r["sep_f1"])} - print(self.evaluation_result) + pprint(self.evaluation_result) # def save_results(self): -- GitLab