diff --git a/convlab/policy/emoTUS/sentiment.json b/convlab/policy/emoTUS/sentiment.json deleted file mode 100644 index 7d39df53b29133f0ef817326c37df481f3c936ed..0000000000000000000000000000000000000000 --- a/convlab/policy/emoTUS/sentiment.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "Neutral": 0, - "Negative": 1, - "Positive": 2 -} \ No newline at end of file diff --git a/convlab/policy/emoTUS/analysis.py b/convlab/policy/emoUS/analysis.py similarity index 98% rename from convlab/policy/emoTUS/analysis.py rename to convlab/policy/emoUS/analysis.py index 4862a2e2df6f9efa01e6f15df9ab27fd6a6ddbcf..37d0c7648c49257944506654d2d8aa7843bada21 100644 --- a/convlab/policy/emoTUS/analysis.py +++ b/convlab/policy/emoUS/analysis.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd -result_dir = "convlab/policy/emoTUS/result" +result_dir = "convlab/policy/emoUS/result" def arg_parser(): @@ -271,7 +271,7 @@ def loop(s0, s1, u1): def dict2csv(data): r = {} - emotion = json.load(open("convlab/policy/emoTUS/emotion.json")) + emotion = json.load(open("convlab/policy/emoUS/emotion.json")) for act, value in data.items(): temp = [0]*(len(emotion)+1) for emo, count in value.items(): diff --git a/convlab/policy/emoTUS/dialogue_collector.py b/convlab/policy/emoUS/dialogue_collector.py similarity index 100% rename from convlab/policy/emoTUS/dialogue_collector.py rename to convlab/policy/emoUS/dialogue_collector.py diff --git a/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json b/convlab/policy/emoUS/emoUS-BertNLU-RuleDST-RulePolicy.json similarity index 84% rename from convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json rename to convlab/policy/emoUS/emoUS-BertNLU-RuleDST-RulePolicy.json index 4bffcb314494dece1686f6e3d807cf82b3cf71ef..84d4dddb8e3b318040d2e2b29f47459dd46cddf5 100644 --- a/convlab/policy/emoTUS/emoTUS-BertNLU-RuleDST-RulePolicy.json +++ b/convlab/policy/emoUS/emoUS-BertNLU-RuleDST-RulePolicy.json @@ -1,6 +1,6 @@ { "model": { - "load_path": "convlab/policy/ppo/finished_experiments/history/NLGEmoTUS/experiment_2023-01-19-17-56-38/save/best_ppo", + "load_path": "convlab/policy/ppo/finished_experiments/history/NLGEmoUS/experiment_2023-01-19-17-56-38/save/best_ppo", "pretrained_load_path": "", "use_pretrained_initialisation": false, "batchsz": 200, @@ -41,10 +41,10 @@ "nlu_usr": {}, "dst_usr": {}, "policy_usr": { - "emoTUS": { - "class_path": "convlab.policy.emoTUS.emoTUS.UserPolicy", + "emoUS": { + "class_path": "convlab.policy.emoUS.emoUS.UserPolicy", "ini_params": { - "model_checkpoint": "convlab/policy/emoTUS/unify/experiments/EmoUS_emowoz+dialmage_0_1/23-01-23-15-03/", + "model_checkpoint": "convlab/policy/emoUS/unify/experiments/EmoUS_emowoz+dialmage_0_1/23-01-23-15-03/", "use_sentiment": false, "add_persona": true, "sample": false, @@ -53,4 +53,4 @@ } }, "usr_nlg": {} -} +} \ No newline at end of file diff --git a/convlab/policy/emoTUS/emoTUS.py b/convlab/policy/emoUS/emoUS.py similarity index 97% rename from convlab/policy/emoTUS/emoTUS.py rename to convlab/policy/emoUS/emoUS.py index ad029dd9cb1e9f9bcc94cde7754073cfe15d5977..165195ccab71cfbee5c4712f042a44e353556685 100644 --- a/convlab/policy/emoTUS/emoTUS.py +++ b/convlab/policy/emoUS/emoUS.py @@ -3,13 +3,13 @@ import json import torch -from convlab.policy.emoTUS.token_map import tokenMap -from convlab.policy.emoTUS.unify.knowledge_graph import KnowledgeGraph +from convlab.policy.emoUS.token_map import tokenMap +from convlab.policy.emoUS.unify.knowledge_graph import KnowledgeGraph from convlab.policy.genTUS.stepGenTUS import \ UserActionPolicy as GenTUSUserActionPolicy from convlab.policy.policy import Policy from convlab.util.custom_util import model_downloader -from convlab.policy.emoTUS.unify.Goal import Goal +from convlab.policy.emoUS.unify.Goal import Goal DEBUG = False @@ -39,7 +39,7 @@ class UserActionPolicy(GenTUSUserActionPolicy): dataset="emowoz", use_sentiment=self.use_sentiment, weight=weight) - data_emotion = json.load(open("convlab/policy/emoTUS/emotion.json")) + data_emotion = json.load(open("convlab/policy/emoUS/emotion.json")) self.emotion_list = [""]*len(data_emotion) for emotion, index in data_emotion.items(): self.emotion_list[index] = emotion @@ -406,13 +406,13 @@ class UserActionPolicy(GenTUSUserActionPolicy): class UserPolicy(Policy): def __init__(self, - model_checkpoint="convlab/policy/emoTUS/unify/default/EmoUS_default", + model_checkpoint="convlab/policy/emoUS/unify/default/EmoUS_default", mode="language", sample=False, action_penalty=False, **kwargs): # self.config = config - print("emoTUS model checkpoint: ", model_checkpoint) + print("emoUS model checkpoint: ", model_checkpoint) if sample: print("EmoUS will sample action, but emotion is always max") if not os.path.exists(os.path.dirname(model_checkpoint)): @@ -467,7 +467,7 @@ if __name__ == "__main__": set_seed(0) # Test semantic level behaviour usr_policy = UserPolicy( - # model_checkpoint, # default location = convlab/policy/emoTUS/unify/default/EmoUS_default + # model_checkpoint, # default location = convlab/policy/emoUS/unify/default/EmoUS_default mode="semantic", sample=True, use_sentiment=use_sentiment, diff --git a/convlab/policy/emoTUS/emotion.json b/convlab/policy/emoUS/emotion.json similarity index 100% rename from convlab/policy/emoTUS/emotion.json rename to convlab/policy/emoUS/emotion.json diff --git a/convlab/policy/emoTUS/emotion_eval.py b/convlab/policy/emoUS/emotion_eval.py similarity index 100% rename from convlab/policy/emoTUS/emotion_eval.py rename to convlab/policy/emoUS/emotion_eval.py diff --git a/convlab/policy/emoTUS/evaluate.py b/convlab/policy/emoUS/evaluate.py similarity index 99% rename from convlab/policy/emoTUS/evaluate.py rename to convlab/policy/emoUS/evaluate.py index 1ea73262ba0a7771633b9e93531e7fc3257ba778..ec1b8e4465c73f9737c7fada8fcaaae89e933784 100644 --- a/convlab/policy/emoTUS/evaluate.py +++ b/convlab/policy/emoUS/evaluate.py @@ -13,7 +13,7 @@ from sklearn import metrics from tqdm import tqdm from convlab.nlg.evaluate import fine_SER -from convlab.policy.emoTUS.emoTUS import UserActionPolicy +from convlab.policy.emoUS.emoUS import UserActionPolicy sys.path.append(os.path.dirname(os.path.dirname( os.path.dirname(os.path.abspath(__file__))))) @@ -76,7 +76,7 @@ class Evaluator: self.r["gen_sentiment"] = [] sent2emo = json.load( - open("convlab/policy/emoTUS/sent2emo.json")) + open("convlab/policy/emoUS/sent2emo.json")) self.emo2sent = {} for sent, emotions in sent2emo.items(): for emo in emotions: diff --git a/convlab/policy/emoTUS/self_bleu.py b/convlab/policy/emoUS/self_bleu.py similarity index 100% rename from convlab/policy/emoTUS/self_bleu.py rename to convlab/policy/emoUS/self_bleu.py diff --git a/convlab/policy/emoTUS/sent2emo.json b/convlab/policy/emoUS/sent2emo.json similarity index 100% rename from convlab/policy/emoTUS/sent2emo.json rename to convlab/policy/emoUS/sent2emo.json diff --git a/convlab/policy/emoTUS/token_map.py b/convlab/policy/emoUS/token_map.py similarity index 100% rename from convlab/policy/emoTUS/token_map.py rename to convlab/policy/emoUS/token_map.py diff --git a/convlab/policy/emoTUS/train_model.py b/convlab/policy/emoUS/train_model.py similarity index 99% rename from convlab/policy/emoTUS/train_model.py rename to convlab/policy/emoUS/train_model.py index 93bc762c19600d97246ed15ef5d72f44ca52d3e6..9883ee4503fe2ca043ddf8519259e53b02c8a864 100644 --- a/convlab/policy/emoTUS/train_model.py +++ b/convlab/policy/emoUS/train_model.py @@ -152,7 +152,7 @@ class TrainerHelper: self.tokenizer = tokenizer self.max_input_length = max_input_length self.max_target_length = max_target_length - self.base_name = "convlab/policy/emoTUS" + self.base_name = "convlab/policy/emoUS" self.dir_name = "" def _get_data_folder(self, model_type, data_name, dial_ids_order=0, split2ratio=1): @@ -186,7 +186,7 @@ class TrainerHelper: def remove_dialmage_action(self): self.dir_name = "fine_tune" - folder = "convlab/policy/emoTUS/unify/data" + folder = "convlab/policy/emoUS/unify/data" data_name = { "emowoz": "EmoUS_emowoz_0_1", "dialmage": "EmoUS_dialmage_0_1_emotion_only"} diff --git a/convlab/policy/emoTUS/unify/Goal.py b/convlab/policy/emoUS/unify/Goal.py similarity index 100% rename from convlab/policy/emoTUS/unify/Goal.py rename to convlab/policy/emoUS/unify/Goal.py diff --git a/convlab/policy/emoTUS/unify/build_data.py b/convlab/policy/emoUS/unify/build_data.py similarity index 96% rename from convlab/policy/emoTUS/unify/build_data.py rename to convlab/policy/emoUS/unify/build_data.py index 1b5d0bdc3dd2072e22d25ec4e7787e1facd71bab..8e84cf7d1e7639d5b11683fddccc5a90a8d9aeed 100644 --- a/convlab/policy/emoTUS/unify/build_data.py +++ b/convlab/policy/emoUS/unify/build_data.py @@ -5,7 +5,7 @@ from argparse import ArgumentParser from tqdm import tqdm -from convlab.policy.emoTUS.unify.Goal import Goal, emotion_info +from convlab.policy.emoUS.unify.Goal import Goal, emotion_info from convlab.policy.genTUS.unify.build_data import \ DataBuilder as GenTUSDataBuilder from convlab.policy.genTUS.unify.Goal import transform_data_act @@ -37,15 +37,15 @@ class DataBuilder(GenTUSDataBuilder): self.emotion_only = kwargs.get("emotion_only", False) self.emotion = {} - for emotion, index in json.load(open("convlab/policy/emoTUS/emotion.json")).items(): + for emotion, index in json.load(open("convlab/policy/emoUS/emotion.json")).items(): self.emotion[int(index)] = emotion if use_sentiment: self.sentiment = {} - for sentiment, index in json.load(open("convlab/policy/emoTUS/sentiment.json")).items(): + for sentiment, index in json.load(open("convlab/policy/emoUS/sentiment.json")).items(): self.sentiment[int(index)] = sentiment self.sent2emo = json.load( - open("convlab/policy/emoTUS/sent2emo.json")) + open("convlab/policy/emoUS/sent2emo.json")) # TODO check excited distribution def _one_dialog(self, dialog, add_history=True, random_order=False, no_status=False): @@ -154,7 +154,7 @@ TODO if __name__ == "__main__": args = arg_parser() - base_name = "convlab/policy/emoTUS/unify/data" + base_name = "convlab/policy/emoUS/unify/data" dir_name = f"{args.dataset}_{args.dial_ids_order}_{args.split2ratio}" use_sentiment = args.use_sentiment diff --git a/convlab/policy/emoTUS/unify/knowledge_graph.py b/convlab/policy/emoUS/unify/knowledge_graph.py similarity index 94% rename from convlab/policy/emoTUS/unify/knowledge_graph.py rename to convlab/policy/emoUS/unify/knowledge_graph.py index e1c39c00b1106d525564e123ab1e651b6161c2a1..7b68c2fe55dc3a0161dd2b1a6136acf0f20cbe41 100644 --- a/convlab/policy/emoTUS/unify/knowledge_graph.py +++ b/convlab/policy/emoUS/unify/knowledge_graph.py @@ -19,7 +19,7 @@ class KnowledgeGraph(GenTUSKnowledgeGraph): if use_sentiment: data_sentiment = json.load( - open("convlab/policy/emoTUS/sentiment.json")) + open("convlab/policy/emoUS/sentiment.json")) self.kg_map = {"sentiment": tokenMap(tokenizer=self.tokenizer)} self.sentiment = [""]*len(data_sentiment) for sentiment, index in data_sentiment.items(): @@ -28,14 +28,14 @@ class KnowledgeGraph(GenTUSKnowledgeGraph): self.kg_map["sentiment"].add_token(sentiment, sentiment) self.kg_map[sentiment] = tokenMap(tokenizer=self.tokenizer) self.sent2emo = json.load( - open("convlab/policy/emoTUS/sent2emo.json")) + open("convlab/policy/emoUS/sent2emo.json")) for sent in self.sent2emo: for emo in self.sent2emo[sent]: self.kg_map[sent].add_token(emo, emo) else: data_emotion = json.load( - open("convlab/policy/emoTUS/emotion.json")) + open("convlab/policy/emoUS/emotion.json")) self.emotion = [""]*len(data_emotion) for emotion, index in data_emotion.items(): self.emotion[index] = emotion diff --git a/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json b/convlab/policy/ppo/emoUS-BertNLU-RuleDST-PPOPolicy.json similarity index 88% rename from convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json rename to convlab/policy/ppo/emoUS-BertNLU-RuleDST-PPOPolicy.json index e756edf37e322202f60dcc5276e11c89df6eb40f..eba5bde86b5387c96dbf2b6be2d89c3e0bb71cdb 100644 --- a/convlab/policy/ppo/emoTUS-BertNLU-RuleDST-PPOPolicy.json +++ b/convlab/policy/ppo/emoUS-BertNLU-RuleDST-PPOPolicy.json @@ -41,10 +41,10 @@ "nlu_usr": {}, "dst_usr": {}, "policy_usr": { - "emoTUS": { - "class_path": "convlab.policy.emoTUS.emoTUS.UserPolicy", + "emoUS": { + "class_path": "convlab.policy.emoUS.emoUS.UserPolicy", "ini_params": { - "model_checkpoint": "convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17", + "model_checkpoint": "convlab/policy/emoUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17", "character": "usr", "mode": "language", "only_action": false, @@ -54,4 +54,4 @@ } }, "usr_nlg": {} -} +} \ No newline at end of file diff --git a/convlab/policy/ppo/sigir23/Sample-emoTUS-BertNLU-RuleDST-PPOPolicy.json b/convlab/policy/ppo/sigir23/Sample-emoTUS-BertNLU-RuleDST-PPOPolicy.json index 0c416815736a1cfd454aeb1e885d46be1ba2c9a7..9e069da530d756239801dcecb98b4d38cf2cee06 100644 --- a/convlab/policy/ppo/sigir23/Sample-emoTUS-BertNLU-RuleDST-PPOPolicy.json +++ b/convlab/policy/ppo/sigir23/Sample-emoTUS-BertNLU-RuleDST-PPOPolicy.json @@ -42,9 +42,9 @@ "dst_usr": {}, "policy_usr": { "RulePolicy": { - "class_path": "convlab.policy.emoTUS.emoTUS.UserPolicy", + "class_path": "convlab.policy.emoUS.emoUS.UserPolicy", "ini_params": { - "model_checkpoint": "convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17", + "model_checkpoint": "convlab/policy/emoUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17", "character": "usr", "mode": "language", "only_action": false, @@ -54,4 +54,4 @@ } }, "usr_nlg": {} -} +} \ No newline at end of file diff --git a/convlab/policy/ppo/sigir23/emoTUS-BertNLU-RuleDST-PPOPolicy.json b/convlab/policy/ppo/sigir23/emoTUS-BertNLU-RuleDST-PPOPolicy.json index 2b99f55dcf65f80335f0d0cb7e0192abbacf9f47..bbbe003c78eef9ea9b195bbea1240cd1868fb2e9 100644 --- a/convlab/policy/ppo/sigir23/emoTUS-BertNLU-RuleDST-PPOPolicy.json +++ b/convlab/policy/ppo/sigir23/emoTUS-BertNLU-RuleDST-PPOPolicy.json @@ -42,9 +42,9 @@ "dst_usr": {}, "policy_usr": { "RulePolicy": { - "class_path": "convlab.policy.emoTUS.emoTUS.UserPolicy", + "class_path": "convlab.policy.emoUS.emoUS.UserPolicy", "ini_params": { - "model_checkpoint": "convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17", + "model_checkpoint": "convlab/policy/emoUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17", "character": "usr", "mode": "language", "only_action": false, @@ -54,4 +54,4 @@ } }, "usr_nlg": {} -} +} \ No newline at end of file diff --git a/convlab/policy/ussT5/emowoz_evaluate.py b/convlab/policy/ussT5/emowoz_evaluate.py index 2e8a07d5a20bf50bde5dd2f103385a490de3f4a7..5b9cd8a329976f8afb3be3c695f5a8bc26d995fd 100644 --- a/convlab/policy/ussT5/emowoz_evaluate.py +++ b/convlab/policy/ussT5/emowoz_evaluate.py @@ -28,7 +28,7 @@ def arg_parser(): def build_data(raw_data): sentiments = {} - for sentiment, index in json.load(open("convlab/policy/emoTUS/sentiment.json")).items(): + for sentiment, index in json.load(open("convlab/policy/emoUS/sentiment.json")).items(): sentiments[int(index)] = sentiment data = {"input_text": [], "target_text": []} for prefix in ["satisfaction score: ", "action prediction: ", "utterance generation: "]: