From 844bc8457912c23962d8db1605d99f2304ad6f8d Mon Sep 17 00:00:00 2001 From: Hsien-Chin Lin <linh@hhu.de> Date: Thu, 13 Apr 2023 14:26:20 +0200 Subject: [PATCH] update --- convlab/policy/emoUS/emoUS.py | 10 +++++----- convlab/policy/emoUS/emotion_eval.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/convlab/policy/emoUS/emoUS.py b/convlab/policy/emoUS/emoUS.py index 165195cc..2f0c24b7 100644 --- a/convlab/policy/emoUS/emoUS.py +++ b/convlab/policy/emoUS/emoUS.py @@ -464,7 +464,7 @@ if __name__ == "__main__": import time use_sentiment, emotion_mid = False, False - set_seed(0) + set_seed(100) # Test semantic level behaviour usr_policy = UserPolicy( # model_checkpoint, # default location = convlab/policy/emoUS/unify/default/EmoUS_default @@ -476,15 +476,15 @@ if __name__ == "__main__": usr_nlu = None # BERTNLU() usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user') usr.init_session() + usr.init_session() print(usr.policy.get_goal()) start = time.time() - print(usr.response([]), usr.policy.get_emotion()) + # print(usr.policy.policy.goal.status) - print(usr.response([["inform", "restaurant", "area", "centre"], - ["request", "restaurant", "food", "?"]]), + print(usr.response([['inform', 'train', 'day', 'saturday']]), usr.policy.get_emotion()) # print(usr.policy.policy.goal.status) - print(usr.response([["request", "restaurant", "price range", "?"]]), + print(usr.response([]), usr.policy.get_emotion()) end = time.time() print("-"*50) diff --git a/convlab/policy/emoUS/emotion_eval.py b/convlab/policy/emoUS/emotion_eval.py index be0c383c..ee9fc616 100644 --- a/convlab/policy/emoUS/emotion_eval.py +++ b/convlab/policy/emoUS/emotion_eval.py @@ -11,7 +11,7 @@ from sklearn import metrics from tqdm import tqdm from convlab.nlg.evaluate import fine_SER -from convlab.policy.emoTUS.emoTUS import UserActionPolicy +from convlab.policy.emoUS.emoUS import UserActionPolicy sys.path.append(os.path.dirname(os.path.dirname( os.path.dirname(os.path.abspath(__file__))))) @@ -72,13 +72,13 @@ class Evaluator: self.emotion_list = [] - for emotion in json.load(open("convlab/policy/emoTUS/emotion.json")): + for emotion in json.load(open("convlab/policy/emoUS/emotion.json")): self.emotion_list.append(emotion) self.r[f"{emotion}_acts"] = [] self.r[f"{emotion}_utts"] = [] sent2emo = json.load( - open("convlab/policy/emoTUS/sent2emo.json")) + open("convlab/policy/emoUS/sent2emo.json")) self.emo2sent = {} for sent, emotions in sent2emo.items(): for emo in emotions: -- GitLab