diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoTUS/emoTUS.py index 8d10e7650bf54ae3e78a34def3e4a31283097ef2..588f087bcb18c790acd037b28ef081a28e456d7f 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoTUS/emoTUS.py @@ -23,6 +23,10 @@ class UserActionPolicy(GenTUSUserActionPolicy): print("use_sentiment: ", self.use_sentiment) print("add_persona: ", self.add_persona) print("emotion_mid: ", self.emotion_mid) + if not os.path.exists(os.path.dirname(model_checkpoint)): + os.makedirs(os.path.dirname(model_checkpoint)) + model_downloader(os.path.dirname(model_checkpoint), + "https://zenodo.org/record/7801525/files/EmoUS_default.zip") super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs) weight = kwargs.get("weight", None) @@ -388,15 +392,16 @@ class UserActionPolicy(GenTUSUserActionPolicy): class UserPolicy(Policy): def __init__(self, - model_checkpoint, + model_checkpoint="convlab/policy/emoTUS/unify/default/EmoUS_default", sample=False, action_penalty=False, **kwargs): # self.config = config + print("emoTUS model checkpoint: ", model_checkpoint) if not os.path.exists(os.path.dirname(model_checkpoint)): - os.mkdir(os.path.dirname(model_checkpoint)) + os.makedirs(os.path.dirname(model_checkpoint)) model_downloader(os.path.dirname(model_checkpoint), - "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip") + "https://zenodo.org/record/7801525/files/EmoUS_default.zip") only_action = False mode = "language" self.policy = UserActionPolicy( @@ -433,20 +438,20 @@ class UserPolicy(Policy): return self.policy.get_goal() return None + def get_emotion(self): + return self.policy.emotion + if __name__ == "__main__": import os - from convlab.dialog_agent import PipelineAgent - # from convlab.nlu.jointBERT.multiwoz import BERTNLU from convlab.util.custom_util import set_seed - use_sentiment, emotion_mid = True, True + use_sentiment, emotion_mid = False, False set_seed(0) # Test semantic level behaviour - model_checkpoint = 'convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17' usr_policy = UserPolicy( - model_checkpoint, + # model_checkpoint, # default location = convlab/policy/emoTUS/unify/default/EmoUS_default sample=True, use_sentiment=use_sentiment, emotion_mid=emotion_mid) @@ -456,10 +461,12 @@ if __name__ == "__main__": usr.init_session() print(usr.policy.get_goal()) - print(usr.response([])) + print(usr.response([]), usr.policy.get_emotion()) # print(usr.policy.policy.goal.status) print(usr.response([["inform", "restaurant", "area", "centre"], - ["request", "restaurant", "food", "?"]])) + ["request", "restaurant", "food", "?"]]), + usr.policy.get_emotion()) # print(usr.policy.policy.goal.status) - print(usr.response([["request", "restaurant", "price range", "?"]])) + print(usr.response([["request", "restaurant", "price range", "?"]]), + usr.policy.get_emotion()) # print(usr.policy.policy.goal.status) diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py index 5ac0d9803deebb48d17033e851bc28074f6a1983..c0ff690f01f52422fd64683964e46a18a51d0dcf 100644 --- a/convlab/policy/genTUS/stepGenTUS.py +++ b/convlab/policy/genTUS/stepGenTUS.py @@ -600,7 +600,7 @@ class UserPolicy(Policy): **kwargs): # self.config = config if not os.path.exists(os.path.dirname(model_checkpoint)): - os.mkdir(os.path.dirname(model_checkpoint)) + os.makedirs(os.path.dirname(model_checkpoint)) model_downloader(os.path.dirname(model_checkpoint), "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")