Commit 40d3d6aa authored by Hsien-Chin Lin

wip

parent aa78d7c5
@@ -2,18 +2,19 @@ import json
 import os
 import sys
 from argparse import ArgumentParser
-import matplotlib.pyplot as plt
+from datetime import datetime
+import matplotlib.pyplot as plt
 import torch
-from convlab.nlg.evaluate import fine_SER
 from datasets import load_metric
 # from convlab.policy.genTUS.pg.stepGenTUSagent import \
 #     stepGenTUSPG as UserPolicy
 from sklearn import metrics
-from convlab.policy.emoTUS.emoTUS import UserActionPolicy
 from tqdm import tqdm
+from convlab.nlg.evaluate import fine_SER
+from convlab.policy.emoTUS.emoTUS import UserActionPolicy
 
 sys.path.append(os.path.dirname(os.path.dirname(
     os.path.dirname(os.path.abspath(__file__)))))
@@ -45,6 +46,7 @@ class Evaluator:
         self.dataset = dataset
         self.model_checkpoint = model_checkpoint
         self.model_weight = model_weight
+        self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
         # if model_weight:
         #     self.usr_policy = UserPolicy(
         #         self.model_checkpoint, only_action=only_action)
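The new self.time attribute stamps every artifact of an evaluation run with the time the Evaluator was constructed, so repeated runs against the same checkpoint no longer overwrite each other's output. A minimal sketch of the value it produces (the printed output is illustrative):

    from datetime import datetime

    # Two-digit year, month, day, hour, minute. The f-string wrapper in the
    # diff is redundant, since strftime() already returns a str.
    time_tag = datetime.now().strftime('%y-%m-%d-%H-%M')
    print(time_tag)  # e.g. 23-01-17-14-05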
@@ -169,9 +171,9 @@ class Evaluator:
         dir_name = self.model_checkpoint
         json.dump(nlg_eval,
-                  open(os.path.join(dir_name, "nlg_eval.json"), 'w'),
+                  open(os.path.join(dir_name, f"{self.time}-nlg_eval.json"), 'w'),
                   indent=2)
-        return os.path.join(dir_name, "nlg_eval.json")
+        return os.path.join(dir_name, f"{self.time}-nlg_eval.json")
 
     def evaluation(self, input_file=None, generated_file=None):
         # TODO add emotion
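The same prefixing pattern recurs for every file the evaluator writes. A self-contained sketch of the nlg_eval dump above, with dir_name and nlg_eval as hypothetical stand-ins for self.model_checkpoint and the collected results:

    import json
    import os
    from datetime import datetime

    dir_name = "results"                       # stand-in for self.model_checkpoint
    nlg_eval = {"metadata": {}, "dialog": []}  # stand-in for the collected results
    time = datetime.now().strftime('%y-%m-%d-%H-%M')

    os.makedirs(dir_name, exist_ok=True)
    file_name = os.path.join(dir_name, f"{time}-nlg_eval.json")
    with open(file_name, 'w') as f:
        json.dump(nlg_eval, f, indent=2)
    # writes e.g. results/23-01-17-14-05-nlg_eval.json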
@@ -229,7 +231,7 @@ class Evaluator:
         for metric in scores:
             result[metric] = sum(scores[metric])/len(scores[metric])
             print(f"{metric}: {result[metric]}")
-        emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint)
+        emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint, time=self.time)
         # for metric in emo_score:
         #     result[metric] = emo_score[metric]
         #     print(f"{metric}: {result[metric]}")
@@ -238,10 +240,10 @@ class Evaluator:
         basename = "semantic_evaluation_result"
         json.dump(result, open(os.path.join(
-            self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w'))
+            self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'))
 
 
-def emotion_score(golden_emotions, gen_emotions, dirname="."):
+def emotion_score(golden_emotions, gen_emotions, dirname=".", time=""):
     labels = ["Neutral", "Fearful", "Dissatisfied",
               "Apologetic", "Abusive", "Excited", "Satisfied"]
     print(labels)
......@@ -252,7 +254,7 @@ def emotion_score(golden_emotions, gen_emotions, dirname="."):
disp = metrics.ConfusionMatrixDisplay(
confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.savefig(os.path.join(dirname, "emotion.png"))
plt.savefig(os.path.join(dirname, f"{time}-emotion.png"))
r = {"macro_f1": float(macro_f1), "sep_f1": list(
sep_f1), "cm": [list(c) for c in list(cm)]}
print(r)
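The hunk only shows the tail of emotion_score, but judging from the names, macro_f1, sep_f1, and cm are presumably standard sklearn metrics over the golden and generated emotion label sequences. A runnable sketch under that assumption, with toy inputs:

    from sklearn import metrics

    labels = ["Neutral", "Fearful", "Dissatisfied",
              "Apologetic", "Abusive", "Excited", "Satisfied"]
    golden_emotions = ["Neutral", "Satisfied", "Fearful", "Neutral"]    # toy data
    gen_emotions = ["Neutral", "Satisfied", "Dissatisfied", "Neutral"]  # toy data

    # One F1 averaged over all classes, one per-class F1 array, and a
    # confusion matrix for the ConfusionMatrixDisplay shown in the diff.
    macro_f1 = metrics.f1_score(golden_emotions, gen_emotions,
                                average="macro", labels=labels)
    sep_f1 = metrics.f1_score(golden_emotions, gen_emotions,
                              average=None, labels=labels)
    cm = metrics.confusion_matrix(golden_emotions, gen_emotions,
                                  labels=labels)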