Skip to content
Snippets Groups Projects
Commit 4816b8fe authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

semantic mode

parent 2d2f3332
No related branches found
No related tags found
No related merge requests found
...@@ -15,19 +15,23 @@ DEBUG = False ...@@ -15,19 +15,23 @@ DEBUG = False
class UserActionPolicy(GenTUSUserActionPolicy): class UserActionPolicy(GenTUSUserActionPolicy):
def __init__(self, model_checkpoint, mode="language", only_action=False, max_turn=40, **kwargs): def __init__(self, model_checkpoint, mode="language", max_turn=40, **kwargs):
self.use_sentiment = kwargs.get("use_sentiment", False) self.use_sentiment = kwargs.get("use_sentiment", False)
self.add_persona = kwargs.get("add_persona", True) self.add_persona = kwargs.get("add_persona", True)
self.emotion_mid = kwargs.get("emotion_mid", False) self.emotion_mid = kwargs.get("emotion_mid", False)
print("===== model status =====")
print("use_sentiment: ", self.use_sentiment)
print("add_persona: ", self.add_persona)
print("emotion_mid: ", self.emotion_mid)
if not os.path.exists(os.path.dirname(model_checkpoint)): if not os.path.exists(os.path.dirname(model_checkpoint)):
os.makedirs(os.path.dirname(model_checkpoint)) os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint), model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7801525/files/EmoUS_default.zip") "https://zenodo.org/record/7801525/files/EmoUS_default.zip")
if mode == "language":
only_action = False
elif mode == "semantic":
only_action = True
else:
raise ValueError("mode should be language or semantic")
super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs) super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs)
weight = kwargs.get("weight", None) weight = kwargs.get("weight", None)
self.kg = KnowledgeGraph( self.kg = KnowledgeGraph(
...@@ -96,7 +100,10 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -96,7 +100,10 @@ class UserActionPolicy(GenTUSUserActionPolicy):
raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent) raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
output = self._parse_output(raw_output) output = self._parse_output(raw_output)
self.semantic_action = self._remove_illegal_action(output["action"]) self.semantic_action = self._remove_illegal_action(output["action"])
if not self.only_action:
self.utterance = output["text"] self.utterance = output["text"]
self.emotion = output["emotion"] self.emotion = output["emotion"]
if self.use_sentiment: if self.use_sentiment:
self.sentiment = output["sentiment"] self.sentiment = output["sentiment"]
...@@ -112,6 +119,9 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -112,6 +119,9 @@ class UserActionPolicy(GenTUSUserActionPolicy):
del inputs del inputs
if self.only_action:
return self.semantic_action
return self.utterance return self.utterance
def _parse_output(self, in_str): def _parse_output(self, in_str):
...@@ -230,10 +240,14 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -230,10 +240,14 @@ class UserActionPolicy(GenTUSUserActionPolicy):
elif not self.use_sentiment and self.emotion_mid: elif not self.use_sentiment and self.emotion_mid:
pos = self._act_emo( pos = self._act_emo(
pos, model_input, mode, emotion_mode, allow_general_intent) pos, model_input, mode, emotion_mode, allow_general_intent)
else: else: # defalut method
pos = self._emo_act( pos = self._emo_act(
pos, model_input, mode, emotion_mode, allow_general_intent) pos, model_input, mode, emotion_mode, allow_general_intent)
if self.only_action:
# return semantic action. Don't need to generate text
return self.vector.decode(self.seq[0, :pos])
pos = self._update_seq(self.token_map.get_id("start_text"), pos) pos = self._update_seq(self.token_map.get_id("start_text"), pos)
text = self._get_text(model_input, pos) text = self._get_text(model_input, pos)
...@@ -393,6 +407,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): ...@@ -393,6 +407,7 @@ class UserActionPolicy(GenTUSUserActionPolicy):
class UserPolicy(Policy): class UserPolicy(Policy):
def __init__(self, def __init__(self,
model_checkpoint="convlab/policy/emoTUS/unify/default/EmoUS_default", model_checkpoint="convlab/policy/emoTUS/unify/default/EmoUS_default",
mode="language",
sample=False, sample=False,
action_penalty=False, action_penalty=False,
**kwargs): **kwargs):
...@@ -402,12 +417,10 @@ class UserPolicy(Policy): ...@@ -402,12 +417,10 @@ class UserPolicy(Policy):
os.makedirs(os.path.dirname(model_checkpoint)) os.makedirs(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint), model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7801525/files/EmoUS_default.zip") "https://zenodo.org/record/7801525/files/EmoUS_default.zip")
only_action = False
mode = "language"
self.policy = UserActionPolicy( self.policy = UserActionPolicy(
model_checkpoint, model_checkpoint,
mode=mode, mode=mode,
only_action=only_action,
action_penalty=action_penalty, action_penalty=action_penalty,
**kwargs) **kwargs)
self.policy.load(os.path.join( self.policy.load(os.path.join(
...@@ -446,12 +459,14 @@ if __name__ == "__main__": ...@@ -446,12 +459,14 @@ if __name__ == "__main__":
import os import os
from convlab.dialog_agent import PipelineAgent from convlab.dialog_agent import PipelineAgent
from convlab.util.custom_util import set_seed from convlab.util.custom_util import set_seed
import time
use_sentiment, emotion_mid = False, False use_sentiment, emotion_mid = False, False
set_seed(0) set_seed(0)
# Test semantic level behaviour # Test semantic level behaviour
usr_policy = UserPolicy( usr_policy = UserPolicy(
# model_checkpoint, # default location = convlab/policy/emoTUS/unify/default/EmoUS_default # model_checkpoint, # default location = convlab/policy/emoTUS/unify/default/EmoUS_default
mode="semantic",
sample=True, sample=True,
use_sentiment=use_sentiment, use_sentiment=use_sentiment,
emotion_mid=emotion_mid) emotion_mid=emotion_mid)
...@@ -460,7 +475,7 @@ if __name__ == "__main__": ...@@ -460,7 +475,7 @@ if __name__ == "__main__":
usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user') usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
usr.init_session() usr.init_session()
print(usr.policy.get_goal()) print(usr.policy.get_goal())
start = time.time()
print(usr.response([]), usr.policy.get_emotion()) print(usr.response([]), usr.policy.get_emotion())
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
print(usr.response([["inform", "restaurant", "area", "centre"], print(usr.response([["inform", "restaurant", "area", "centre"],
...@@ -469,4 +484,7 @@ if __name__ == "__main__": ...@@ -469,4 +484,7 @@ if __name__ == "__main__":
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
print(usr.response([["request", "restaurant", "price range", "?"]]), print(usr.response([["request", "restaurant", "price range", "?"]]),
usr.policy.get_emotion()) usr.policy.get_emotion())
end = time.time()
print("-"*50)
print("time: ", end - start)
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment