diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoTUS/evaluate.py
index 60a15e05c08cca45ada3600faeb3bf02f8fbe2ea..9d587e4a57eb2d71bb49b023a2d764b711eb43d9 100644
--- a/convlab/policy/emoTUS/evaluate.py
+++ b/convlab/policy/emoTUS/evaluate.py
@@ -49,7 +49,7 @@ class Evaluator:
         self.use_sentiment = kwargs.get("use_sentiment", False)
         self.add_persona = kwargs.get("add_persona", False)
         self.emotion_mid = kwargs.get("emotion_mid", False)
-        weight = kwargs.get("weight", None)
+        self.emotion_weight = kwargs.get("weight", None)
         self.sample = kwargs.get("sample", False)
 
         self.usr = UserActionPolicy(
@@ -58,7 +58,7 @@ class Evaluator:
             use_sentiment=self.use_sentiment,
             add_persona=self.add_persona,
             emotion_mid=self.emotion_mid,
-            weight=weight)
+            weight=self.emotion_weight)
 
         self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
 
@@ -169,6 +169,7 @@ class Evaluator:
         nlg_eval["golden"] = False
 
         nlg_eval["mode"] = mode
+        nlg_eval["emotion_weight"] = self.emotion_weight
         nlg_eval["metrics"] = {}
         nlg_eval["dialog"] = self._transform_result()
 
@@ -237,6 +238,7 @@ class Evaluator:
                 scores[metric].append(s[metric])
 
         result = {}
+        result["emotion_weight"] = self.emotion_weight
         for metric in scores:
             result[metric] = sum(scores[metric])/len(scores[metric])
             print(f"{metric}: {result[metric]}")
@@ -276,8 +278,11 @@ class Evaluator:
         result["dialog"] = dialog_result
 
         basename = "semantic_evaluation_result"
-        json.dump(result, open(os.path.join(
-            self.model_checkpoint, f"{self.time}-{self.dataset}-{basename}.json"), 'w'))
+        json.dump(
+            result,
+            open(os.path.join(self.model_checkpoint,
+                              f"{self.time}-{self.dataset}-{basename}.json"), 'w'),
+            indent=2)
 
 
 def emotion_score(golden_emotions, gen_emotions, dirname=".", time="", no_neutral=False):