diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py
index fca0f3a9dc3e44283ebcf959d74bf7eeaeeec61b..2fb0790f19b07c50c97b3aa652db0615ddc8dd60 100644
--- a/convlab/policy/emoTUS/emoTUS.py
+++ b/convlab/policy/emoTUS/emoTUS.py
@@ -91,7 +91,6 @@ class UserActionPolicy(GenTUSUserActionPolicy):
         raw_output = self._generate_action(
             raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
         output = self._parse_output(raw_output)
-        print(output)
         self.semantic_action = self._remove_illegal_action(output["action"])
         self.utterance = output["text"]
         self.emotion = output["emotion"]
diff --git a/convlab/policy/emoTUS/emotion_eval.py b/convlab/policy/emoTUS/emotion_eval.py
index 0d78f16a0beee92bca4adb46e60997411288b1d2..da1479f14dc955367a5915cfc586d02b7b1d4cea 100644
--- a/convlab/policy/emoTUS/emotion_eval.py
+++ b/convlab/policy/emoTUS/emotion_eval.py
@@ -184,8 +184,8 @@ class Evaluator:
 
         scores = {}
         for emotion in self.emotion_list:
-            if emotion == "Neutral":
-                continue
+            # if emotion == "Neutral":
+            #     continue
             scores[emotion] = {"precision": [],
                                "recall": [], "f1": [], "turn_acc": []}
             for gen_act, golden_act in zip(r[f"{emotion}_acts"], r["Neutral_acts"]):
@@ -195,16 +195,23 @@ class Evaluator:
 
         result = {}
         for emotion in self.emotion_list:
-            if emotion == "Neutral":
-                continue
+            # if emotion == "Neutral":
+            #     continue
             result[emotion] = {}
+            for metric in scores[emotion]:
+                result[emotion][metric] = sum(
+                    scores[emotion][metric])/len(scores[emotion][metric])
             result[emotion]["bleu"] = bleu(golden_utts=r["Neutral_utts"],
                                            gen_utts=r[f"{emotion}_utts"])
             result[emotion]["SER"] = SER(gen_utts=r[f"{emotion}_utts"],
                                          gen_acts=r[f"{emotion}_acts"])
-            for metric in scores[emotion]:
-                result[emotion][metric] = sum(
-                    scores[emotion][metric])/len(scores[emotion][metric])
+
+            result[emotion]["len"] = avg_len(gen_utts=r[f"{emotion}_utts"])
+
+            rouge_score = rouge(golden_utts=r["Neutral_utts"],
+                                gen_utts=r[f"{emotion}_utts"])
+            for metric, score in rouge_score.items():
+                result[emotion][metric] = score.mid.fmeasure
 
             print("emotion:", emotion)
             for metric in result[emotion]:
@@ -221,6 +228,11 @@ class Evaluator:
             self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'), indent=2)
 
 
+def avg_len(gen_utts):
+    n = [len(s.split()) for s in gen_utts]
+    return sum(n)/len(n)
+
+
 def bleu(golden_utts, gen_utts):
     bleu_metric = load_metric("sacrebleu")
     labels = [[utt] for utt in golden_utts]
@@ -231,6 +243,13 @@ def bleu(golden_utts, gen_utts):
     return bleu_score["score"]
 
 
+def rouge(golden_utts, gen_utts):
+    rouge_metric = load_metric("rouge")
+    rouge_score = rouge_metric.compute(predictions=gen_utts,
+                                       references=golden_utts)
+    return rouge_score
+
+
 def SER(gen_utts, gen_acts):
     missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(
         gen_acts, gen_utts)
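
For anyone who wants to try the new `avg_len`/`rouge` helpers outside a full `Evaluator` run, here is a minimal, illustrative sketch (not part of the diff). It assumes `convlab` is installed along with the `datasets` library and its `rouge_score` backend, and the Neutral/Fearful utterances below are made up; the `.mid.fmeasure` access mirrors what `emotion_eval` now stores per emotion.

```python
# Illustrative only: exercises the new module-level helpers added in this diff
# on made-up utterances. Assumes convlab, datasets, and rouge_score are installed.
from convlab.policy.emoTUS.emotion_eval import avg_len, rouge

neutral_utts = ["i am looking for a cheap hotel in the north .",
                "what is its phone number ?"]
fearful_utts = ["i really need a cheap hotel in the north , please help me .",
                "could you please give me its phone number ?"]

# avg_len: mean whitespace-token count of the generated utterances.
print(avg_len(fearful_utts))

# rouge: load_metric("rouge") returns one AggregateScore per ROUGE variant
# (rouge1, rouge2, rougeL, rougeLsum); .mid.fmeasure is the value the
# evaluator records for each emotion.
for metric, score in rouge(golden_utts=neutral_utts, gen_utts=fearful_utts).items():
    print(metric, score.mid.fmeasure)
```

Because `load_metric("rouge")` aggregates over the whole corpus by default, the helper returns corpus-level `AggregateScore` objects rather than per-utterance scores, which is why only the mid F-measure is kept in the result table.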