Skip to content
Snippets Groups Projects
Commit 58bd6ce2 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

wip

parent f48c40eb
Branches
Tags
No related merge requests found
......@@ -49,7 +49,7 @@ class Evaluator:
self.use_sentiment = kwargs.get("use_sentiment", False)
self.add_persona = kwargs.get("add_persona", False)
self.emotion_mid = kwargs.get("emotion_mid", False)
weight = kwargs.get("weight", None)
self.emotion_weight = kwargs.get("weight", None)
self.sample = kwargs.get("sample", False)
self.usr = UserActionPolicy(
......@@ -58,7 +58,7 @@ class Evaluator:
use_sentiment=self.use_sentiment,
add_persona=self.add_persona,
emotion_mid=self.emotion_mid,
weight=weight)
weight=self.emotion_weight)
self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
......@@ -169,6 +169,7 @@ class Evaluator:
nlg_eval["golden"] = False
nlg_eval["mode"] = mode
nlg_eval["emotion_weight"] = self.emotion_weight
nlg_eval["metrics"] = {}
nlg_eval["dialog"] = self._transform_result()
......@@ -237,6 +238,7 @@ class Evaluator:
scores[metric].append(s[metric])
result = {}
result["emotion_weight"] = self.emotion_weight
for metric in scores:
result[metric] = sum(scores[metric])/len(scores[metric])
print(f"{metric}: {result[metric]}")
......@@ -276,8 +278,11 @@ class Evaluator:
result["dialog"] = dialog_result
basename = "semantic_evaluation_result"
json.dump(result, open(os.path.join(
self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'))
json.dump(
result,
open(os.path.join(self.model_checkpoint,
f"{self.time}-{self.dataset}-{basename}.json"), 'w'),
indent=2)
def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutral=False):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment