Skip to content
Snippets Groups Projects
Commit 844bc845 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

update

parent fdf13e1f
No related branches found
No related tags found
No related merge requests found
...@@ -464,7 +464,7 @@ if __name__ == "__main__": ...@@ -464,7 +464,7 @@ if __name__ == "__main__":
import time import time
use_sentiment, emotion_mid = False, False use_sentiment, emotion_mid = False, False
set_seed(0) set_seed(100)
# Test semantic level behaviour # Test semantic level behaviour
usr_policy = UserPolicy( usr_policy = UserPolicy(
# model_checkpoint, # default location = convlab/policy/emoUS/unify/default/EmoUS_default # model_checkpoint, # default location = convlab/policy/emoUS/unify/default/EmoUS_default
...@@ -476,15 +476,15 @@ if __name__ == "__main__": ...@@ -476,15 +476,15 @@ if __name__ == "__main__":
usr_nlu = None # BERTNLU() usr_nlu = None # BERTNLU()
usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user') usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
usr.init_session() usr.init_session()
usr.init_session()
print(usr.policy.get_goal()) print(usr.policy.get_goal())
start = time.time() start = time.time()
print(usr.response([]), usr.policy.get_emotion())
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
print(usr.response([["inform", "restaurant", "area", "centre"], print(usr.response([['inform', 'train', 'day', 'saturday']]),
["request", "restaurant", "food", "?"]]),
usr.policy.get_emotion()) usr.policy.get_emotion())
# print(usr.policy.policy.goal.status) # print(usr.policy.policy.goal.status)
print(usr.response([["request", "restaurant", "price range", "?"]]), print(usr.response([]),
usr.policy.get_emotion()) usr.policy.get_emotion())
end = time.time() end = time.time()
print("-"*50) print("-"*50)
......
...@@ -11,7 +11,7 @@ from sklearn import metrics ...@@ -11,7 +11,7 @@ from sklearn import metrics
from tqdm import tqdm from tqdm import tqdm
from convlab.nlg.evaluate import fine_SER from convlab.nlg.evaluate import fine_SER
from convlab.policy.emoTUS.emoTUS import UserActionPolicy from convlab.policy.emoUS.emoUS import UserActionPolicy
sys.path.append(os.path.dirname(os.path.dirname( sys.path.append(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__))))) os.path.dirname(os.path.abspath(__file__)))))
...@@ -72,13 +72,13 @@ class Evaluator: ...@@ -72,13 +72,13 @@ class Evaluator:
self.emotion_list = [] self.emotion_list = []
for emotion in json.load(open("convlab/policy/emoTUS/emotion.json")): for emotion in json.load(open("convlab/policy/emoUS/emotion.json")):
self.emotion_list.append(emotion) self.emotion_list.append(emotion)
self.r[f"{emotion}_acts"] = [] self.r[f"{emotion}_acts"] = []
self.r[f"{emotion}_utts"] = [] self.r[f"{emotion}_utts"] = []
sent2emo = json.load( sent2emo = json.load(
open("convlab/policy/emoTUS/sent2emo.json")) open("convlab/policy/emoUS/sent2emo.json"))
self.emo2sent = {} self.emo2sent = {}
for sent, emotions in sent2emo.items(): for sent, emotions in sent2emo.items():
for emo in emotions: for emo in emotions:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment