Skip to content
Snippets Groups Projects
Commit 2d2f3332 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

update model downloading

parent f99898c1
No related branches found
No related tags found
No related merge requests found
...@@ -23,6 +23,10 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -23,6 +23,10 @@ class UserActionPolicy(GenTUSUserActionPolicy):
print("use_sentiment: ", self.use_sentiment) print("use_sentiment: ", self.use_sentiment)
print("add_persona: ", self.add_persona) print("add_persona: ", self.add_persona)
print("emotion_mid: ", self.emotion_mid) print("emotion_mid: ", self.emotion_mid)
if not os.path.exists(os.path.dirname(model_checkpoint)):
os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7801525/files/EmoUS_default.zip")
super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs) super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs)
weight = kwargs.get("weight", None) weight = kwargs.get("weight", None)
...@@ -388,15 +392,16 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -388,15 +392,16 @@ class UserActionPolicy(GenTUSUserActionPolicy):
class UserPolicy(Policy): class UserPolicy(Policy):
def __init__(self, def __init__(self,
model_checkpoint, model_checkpoint="convlab/policy/emoTUS/unify/default/EmoUS_default",
sample=False, sample=False,
action_penalty=False, action_penalty=False,
**kwargs): **kwargs):
# self.config = config # self.config = config
print("emoTUS model checkpoint: ", model_checkpoint)
if not os.path.exists(os.path.dirname(model_checkpoint)): if not os.path.exists(os.path.dirname(model_checkpoint)):
os.mkdir(os.path.dirname(model_checkpoint)) os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint), model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7372442/files/multiwoz21-exp.zip") "https://zenodo.org/record/7801525/files/EmoUS_default.zip")
only_action = False only_action = False
mode = "language" mode = "language"
self.policy = UserActionPolicy( self.policy = UserActionPolicy(
...@@ -433,20 +438,20 @@ class UserPolicy(Policy): ...@@ -433,20 +438,20 @@ class UserPolicy(Policy):
return self.policy.get_goal() return self.policy.get_goal()
return None return None
def get_emotion(self):
return self.policy.emotion
if __name__ == "__main__": if __name__ == "__main__":
import os import os
from convlab.dialog_agent import PipelineAgent from convlab.dialog_agent import PipelineAgent
# from convlab.nlu.jointBERT.multiwoz import BERTNLU
from convlab.util.custom_util import set_seed from convlab.util.custom_util import set_seed
use_sentiment, emotion_mid = True, True use_sentiment, emotion_mid = False, False
set_seed(0) set_seed(0)
# Test semantic level behaviour # Test semantic level behaviour
model_checkpoint = 'convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17'
usr_policy = UserPolicy( usr_policy = UserPolicy(
model_checkpoint, # model_checkpoint, # default location = convlab/policy/emoTUS/unify/default/EmoUS_default
sample=True, sample=True,
use_sentiment=use_sentiment, use_sentiment=use_sentiment,
emotion_mid=emotion_mid) emotion_mid=emotion_mid)
...@@ -456,10 +461,12 @@ if __name__ == "__main__": ...@@ -456,10 +461,12 @@ if __name__ == "__main__":
usr.init_session() usr.init_session()
print(usr.policy.get_goal()) print(usr.policy.get_goal())
print(usr.response([])) print(usr.response([]), usr.policy.get_emotion())
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
print(usr.response([["inform", "restaurant", "area", "centre"], print(usr.response([["inform", "restaurant", "area", "centre"],
["request", "restaurant", "food", "?"]])) ["request", "restaurant", "food", "?"]]),
usr.policy.get_emotion())
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
print(usr.response([["request", "restaurant", "price range", "?"]])) print(usr.response([["request", "restaurant", "price range", "?"]]),
usr.policy.get_emotion())
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
...@@ -600,7 +600,7 @@ class UserPolicy(Policy): ...@@ -600,7 +600,7 @@ class UserPolicy(Policy):
**kwargs): **kwargs):
# self.config = config # self.config = config
if not os.path.exists(os.path.dirname(model_checkpoint)): if not os.path.exists(os.path.dirname(model_checkpoint)):
os.mkdir(os.path.dirname(model_checkpoint)) os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint), model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7372442/files/multiwoz21-exp.zip") "https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment