diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py index 8590317da67d3d03eb73b24e4a73d17fbf998bf5..dadfc923e48fdcefbb18c210f250d81f52e7022a 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoTUS/emoTUS.py @@ -17,8 +17,12 @@ DEBUG = False class UserActionPolicy(GenTUSUserActionPolicy): def __init__(self, model_checkpoint, mode="language", only_action=False, max_turn=40, **kwargs): self.use_sentiment = kwargs.get("use_sentiment", False) - self.add_persona = kwargs.get("add_persona", False) + self.add_persona = kwargs.get("add_persona", True) self.emotion_mid = kwargs.get("emotion_mid", False) + print("===== model status =====") + print("use_sentiment: ", self.use_sentiment) + print("add_persona: ", self.add_persona) + print("emotion_mid: ", self.emotion_mid) super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs) weight = kwargs.get("weight", None)