Skip to content
Snippets Groups Projects
Commit 40d3d6aa authored by Hsien-Chin Lin
Browse files

wip

parent aa78d7c5
Branches
No related tags found
No related merge requests found
...@@ -2,18 +2,19 @@ import json ...@@ -2,18 +2,19 @@ import json
import os import os
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
import matplotlib.pyplot as plt from datetime import datetime
import matplotlib.pyplot as plt
import torch import torch
from convlab.nlg.evaluate import fine_SER
from datasets import load_metric from datasets import load_metric
# from convlab.policy.genTUS.pg.stepGenTUSagent import \ # from convlab.policy.genTUS.pg.stepGenTUSagent import \
# stepGenTUSPG as UserPolicy # stepGenTUSPG as UserPolicy
from sklearn import metrics from sklearn import metrics
from convlab.policy.emoTUS.emoTUS import UserActionPolicy
from tqdm import tqdm from tqdm import tqdm
from convlab.nlg.evaluate import fine_SER
from convlab.policy.emoTUS.emoTUS import UserActionPolicy
sys.path.append(os.path.dirname(os.path.dirname( sys.path.append(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__))))) os.path.dirname(os.path.abspath(__file__)))))
...@@ -45,6 +46,7 @@ class Evaluator: ...@@ -45,6 +46,7 @@ class Evaluator:
self.dataset = dataset self.dataset = dataset
self.model_checkpoint = model_checkpoint self.model_checkpoint = model_checkpoint
self.model_weight = model_weight self.model_weight = model_weight
self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
# if model_weight: # if model_weight:
# self.usr_policy = UserPolicy( # self.usr_policy = UserPolicy(
# self.model_checkpoint, only_action=only_action) # self.model_checkpoint, only_action=only_action)
...@@ -169,9 +171,9 @@ class Evaluator: ...@@ -169,9 +171,9 @@ class Evaluator:
dir_name = self.model_checkpoint dir_name = self.model_checkpoint
json.dump(nlg_eval, json.dump(nlg_eval,
open(os.path.join(dir_name, "nlg_eval.json"), 'w'), open(os.path.join(dir_name, f"{self.time}-nlg_eval.json"), 'w'),
indent=2) indent=2)
return os.path.join(dir_name, "nlg_eval.json") return os.path.join(dir_name, f"{self.time}-nlg_eval.json")
def evaluation(self, input_file=None, generated_file=None): def evaluation(self, input_file=None, generated_file=None):
# TODO add emotion # TODO add emotion
...@@ -229,7 +231,7 @@ class Evaluator: ...@@ -229,7 +231,7 @@ class Evaluator:
for metric in scores: for metric in scores:
result[metric] = sum(scores[metric])/len(scores[metric]) result[metric] = sum(scores[metric])/len(scores[metric])
print(f"{metric}: {result[metric]}") print(f"{metric}: {result[metric]}")
emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint) emo_score = emotion_score(golden_emotions, gen_emotions, self.model_checkpoint, time=self.time)
# for metric in emo_score: # for metric in emo_score:
# result[metric] = emo_score[metric] # result[metric] = emo_score[metric]
# print(f"{metric}: {result[metric]}") # print(f"{metric}: {result[metric]}")
...@@ -238,10 +240,10 @@ class Evaluator: ...@@ -238,10 +240,10 @@ class Evaluator:
basename = "semantic_evaluation_result" basename = "semantic_evaluation_result"
json.dump(result, open(os.path.join( json.dump(result, open(os.path.join(
self.model_checkpoint, f"{self.dataset}-{basename}.json"), 'w')) self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'))
def emotion_score(golden_emotions, gen_emotions, dirname="."): def emotion_score(golden_emotions, gen_emotions, dirname=".", time=""):
labels = ["Neutral", "Fearful", "Dissatisfied", labels = ["Neutral", "Fearful", "Dissatisfied",
"Apologetic", "Abusive", "Excited", "Satisfied"] "Apologetic", "Abusive", "Excited", "Satisfied"]
print(labels) print(labels)
...@@ -252,7 +254,7 @@ def emotion_score(golden_emotions, gen_emotions, dirname="."): ...@@ -252,7 +254,7 @@ def emotion_score(golden_emotions, gen_emotions, dirname="."):
disp = metrics.ConfusionMatrixDisplay( disp = metrics.ConfusionMatrixDisplay(
confusion_matrix=cm, display_labels=labels) confusion_matrix=cm, display_labels=labels)
disp.plot() disp.plot()
plt.savefig(os.path.join(dirname, "emotion.png")) plt.savefig(os.path.join(dirname, f"{time}-emotion.png"))
r = {"macro_f1": float(macro_f1), "sep_f1": list( r = {"macro_f1": float(macro_f1), "sep_f1": list(
sep_f1), "cm": [list(c) for c in list(cm)]} sep_f1), "cm": [list(c) for c in list(cm)]}
print(r) print(r)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment