From 844bc8457912c23962d8db1605d99f2304ad6f8d Mon Sep 17 00:00:00 2001
From: Hsien-Chin Lin <linh@hhu.de>
Date: Thu, 13 Apr 2023 14:26:20 +0200
Subject: [PATCH] update

---
 convlab/policy/emoUS/emoUS.py        | 10 +++++-----
 convlab/policy/emoUS/emotion_eval.py |  6 +++---
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/convlab/policy/emoUS/emoUS.py b/convlab/policy/emoUS/emoUS.py
index 165195cc..2f0c24b7 100644
--- a/convlab/policy/emoUS/emoUS.py
+++ b/convlab/policy/emoUS/emoUS.py
@@ -464,7 +464,7 @@ if __name__ == "__main__":
     import time
 
     use_sentiment, emotion_mid = False, False
-    set_seed(0)
+    set_seed(100)
     # Test semantic level behaviour
     usr_policy = UserPolicy(
         # model_checkpoint, # default location = convlab/policy/emoUS/unify/default/EmoUS_default
@@ -476,15 +476,15 @@ if __name__ == "__main__":
     usr_nlu = None  # BERTNLU()
     usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
     usr.init_session()
+    usr.init_session()
     print(usr.policy.get_goal())
     start = time.time()
-    print(usr.response([]), usr.policy.get_emotion())
+
     # print(usr.policy.policy.goal.status)
-    print(usr.response([["inform", "restaurant", "area", "centre"],
-                        ["request", "restaurant", "food", "?"]]),
+    print(usr.response([['inform', 'train', 'day', 'saturday']]),
           usr.policy.get_emotion())
     # print(usr.policy.policy.goal.status)
-    print(usr.response([["request", "restaurant", "price range", "?"]]),
+    print(usr.response([]),
           usr.policy.get_emotion())
     end = time.time()
     print("-"*50)
diff --git a/convlab/policy/emoUS/emotion_eval.py b/convlab/policy/emoUS/emotion_eval.py
index be0c383c..ee9fc616 100644
--- a/convlab/policy/emoUS/emotion_eval.py
+++ b/convlab/policy/emoUS/emotion_eval.py
@@ -11,7 +11,7 @@ from sklearn import metrics
 from tqdm import tqdm
 
 from convlab.nlg.evaluate import fine_SER
-from convlab.policy.emoTUS.emoTUS import UserActionPolicy
+from convlab.policy.emoUS.emoUS import UserActionPolicy
 
 sys.path.append(os.path.dirname(os.path.dirname(
     os.path.dirname(os.path.abspath(__file__)))))
@@ -72,13 +72,13 @@ class Evaluator:
 
         self.emotion_list = []
 
-        for emotion in json.load(open("convlab/policy/emoTUS/emotion.json")):
+        for emotion in json.load(open("convlab/policy/emoUS/emotion.json")):
             self.emotion_list.append(emotion)
             self.r[f"{emotion}_acts"] = []
             self.r[f"{emotion}_utts"] = []
 
         sent2emo = json.load(
-            open("convlab/policy/emoTUS/sent2emo.json"))
+            open("convlab/policy/emoUS/sent2emo.json"))
         self.emo2sent = {}
         for sent, emotions in sent2emo.items():
             for emo in emotions:
-- 
GitLab