Skip to content
Snippets Groups Projects
Commit bfb58117 authored by Hsien-Chin Lin's avatar Hsien-Chin Lin
Browse files

Merge branch 'GenTUS-training' into 'github_master'

Gen tus training

See merge request dsml/convlab/ConvLab3!42
parents 76fbb570 54006c3e
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ class UserActionPolicy(Policy):
print("change mode to semantic because only_action=True")
self.mode = "semantic"
self.max_in_len = 500
self.max_out_len = 50 if only_action else 200
self.max_out_len = 100 if only_action else 200
max_act_len = kwargs.get("max_act_len", 2)
print("max_act_len", max_act_len)
self.max_action_len = max_act_len
......
{
"model": {
"load_path": "convlab/policy/ppo/pretrained_models/supervised",
"load_path": "convlab/policy/ppo/pretrained_models/mle",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 1000,
"batchsz": 500,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
......
{
"model": {
"load_path": "",
"load_path": "convlab/policy/ppo/pretrained_models/mle",
"use_pretrained_initialisation": false,
"pretrained_load_path": "",
"batchsz": 1000,
"batchsz": 500,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
"process_num": 4,
"sys_semantic_to_usr": false,
"num_eval_dialogues": 500
"num_eval_dialogues": 200
},
"vectorizer_sys": {
"uncertainty_vector_mul": {
......
{
"model": {
"load_path": "convlab/policy/mle/experiments/experiment_2022-05-23-14-08-43/save/supervised",
"load_path": "convlab/policy/ppo/pretrained_models/mle",
"use_pretrained_initialisation": false,
"pretrained_load_path": "",
"batchsz": 1000,
......@@ -35,7 +35,7 @@
"TUSPolicy": {
"class_path": "convlab.policy.tus.unify.TUS.UserPolicy",
"ini_params": {
"config": "convlab/policy/tus/unify/exp/all.json"
"config": "convlab/policy/tus/unify/exp/multiwoz.json"
}
}
},
......
......@@ -134,7 +134,7 @@ class UserActionPolicy(Policy):
goal = Goal(goal_list)
else:
goal = ABUS_Goal(self.goal_gen)
self.raw_gaol = goal.domain_goals
self.raw_goal = goal.domain_goals
goal_list = old_goal2list(goal.domain_goals)
goal = Goal(goal_list)
......@@ -411,7 +411,8 @@ class UserPolicy(Policy):
self.config = json.load(open(config))
else:
self.config = config
self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz'
self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}'
print("model_dir", self.config['model_dir'])
if not os.path.exists(self.config["model_dir"]):
# os.mkdir(self.config["model_dir"])
model_downloader(os.path.dirname(self.config["model_dir"]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment