Skip to content
Snippets Groups Projects
Commit 58bd6ce2 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent f48c40eb
No related branches found
No related tags found
No related merge requests found
...@@ -49,7 +49,7 @@ class Evaluator: ...@@ -49,7 +49,7 @@ class Evaluator:
self.use_sentiment = kwargs.get("use_sentiment", False) self.use_sentiment = kwargs.get("use_sentiment", False)
self.add_persona = kwargs.get("add_persona", False) self.add_persona = kwargs.get("add_persona", False)
self.emotion_mid = kwargs.get("emotion_mid", False) self.emotion_mid = kwargs.get("emotion_mid", False)
weight = kwargs.get("weight", None) self.emotion_weight = kwargs.get("weight", None)
self.sample = kwargs.get("sample", False) self.sample = kwargs.get("sample", False)
self.usr = UserActionPolicy( self.usr = UserActionPolicy(
...@@ -58,7 +58,7 @@ class Evaluator: ...@@ -58,7 +58,7 @@ class Evaluator:
use_sentiment=self.use_sentiment, use_sentiment=self.use_sentiment,
add_persona=self.add_persona, add_persona=self.add_persona,
emotion_mid=self.emotion_mid, emotion_mid=self.emotion_mid,
weight=weight) weight=self.emotion_weight)
self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin")) self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
...@@ -169,6 +169,7 @@ class Evaluator: ...@@ -169,6 +169,7 @@ class Evaluator:
nlg_eval["golden"] = False nlg_eval["golden"] = False
nlg_eval["mode"] = mode nlg_eval["mode"] = mode
nlg_eval["emotion_weight"] = self.emotion_weight
nlg_eval["metrics"] = {} nlg_eval["metrics"] = {}
nlg_eval["dialog"] = self._transform_result() nlg_eval["dialog"] = self._transform_result()
...@@ -237,6 +238,7 @@ class Evaluator: ...@@ -237,6 +238,7 @@ class Evaluator:
scores[metric].append(s[metric]) scores[metric].append(s[metric])
result = {} result = {}
result["emotion_weight"] = self.emotion_weight
for metric in scores: for metric in scores:
result[metric] = sum(scores[metric])/len(scores[metric]) result[metric] = sum(scores[metric])/len(scores[metric])
print(f"{metric}: {result[metric]}") print(f"{metric}: {result[metric]}")
...@@ -276,8 +278,11 @@ class Evaluator: ...@@ -276,8 +278,11 @@ class Evaluator:
result["dialog"] = dialog_result result["dialog"] = dialog_result
basename = "semantic_evaluation_result" basename = "semantic_evaluation_result"
json.dump(result, open(os.path.join( json.dump(
self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w')) result,
open(os.path.join(self.model_checkpoint,
f"{self.time}-{self.dataset}-{basename}.json"), 'w'),
indent=2)
def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutral=False): def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutral=False):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment