Skip to content
Snippets Groups Projects
Commit 2d2f3332 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

update model downloading

parent f99898c1
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,10 @@ class UserActionPolicy(GenTUSUserActionPolicy):
print("use_sentiment: ", self.use_sentiment)
print("add_persona: ", self.add_persona)
print("emotion_mid: ", self.emotion_mid)
if not os.path.exists(os.path.dirname(model_checkpoint)):
os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7801525/files/EmoUS_default.zip")
super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs)
weight = kwargs.get("weight", None)
......@@ -388,15 +392,16 @@ class UserActionPolicy(GenTUSUserActionPolicy):
class UserPolicy(Policy):
def __init__(self,
model_checkpoint,
model_checkpoint="convlab/policy/emoTUS/unify/default/EmoUS_default",
sample=False,
action_penalty=False,
**kwargs):
# self.config = config
print("emoTUS model checkpoint: ", model_checkpoint)
if not os.path.exists(os.path.dirname(model_checkpoint)):
os.mkdir(os.path.dirname(model_checkpoint))
os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
"https://zenodo.org/record/7801525/files/EmoUS_default.zip")
only_action = False
mode = "language"
self.policy = UserActionPolicy(
......@@ -433,20 +438,20 @@ class UserPolicy(Policy):
return self.policy.get_goal()
return None
def get_emotion(self):
return self.policy.emotion
if __name__ == "__main__":
import os
from convlab.dialog_agent import PipelineAgent
# from convlab.nlu.jointBERT.multiwoz import BERTNLU
from convlab.util.custom_util import set_seed
use_sentiment, emotion_mid = True, True
use_sentiment, emotion_mid = False, False
set_seed(0)
# Test semantic level behaviour
model_checkpoint = 'convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17'
usr_policy = UserPolicy(
model_checkpoint,
# model_checkpoint, # default location = convlab/policy/emoTUS/unify/default/EmoUS_default
sample=True,
use_sentiment=use_sentiment,
emotion_mid=emotion_mid)
......@@ -456,10 +461,12 @@ if __name__ == "__main__":
usr.init_session()
print(usr.policy.get_goal())
print(usr.response([]))
print(usr.response([]), usr.policy.get_emotion())
# print(usr.policy.policy.goal.status)
print(usr.response([["inform", "restaurant", "area", "centre"],
["request", "restaurant", "food", "?"]]))
["request", "restaurant", "food", "?"]]),
usr.policy.get_emotion())
# print(usr.policy.policy.goal.status)
print(usr.response([["request", "restaurant", "price range", "?"]]))
print(usr.response([["request", "restaurant", "price range", "?"]]),
usr.policy.get_emotion())
# print(usr.policy.policy.goal.status)
......@@ -600,7 +600,7 @@ class UserPolicy(Policy):
**kwargs):
# self.config = config
if not os.path.exists(os.path.dirname(model_checkpoint)):
os.mkdir(os.path.dirname(model_checkpoint))
os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment