diff --git a/convlab/policy/genTUS/stepGenTUS.py b/convlab/policy/genTUS/stepGenTUS.py index f16c0ebeeeaaba50f1739aea6c4db40eb81f8d29..1d8e89067491e14e07eda12e17fc1facbafdab8b 100644 --- a/convlab/policy/genTUS/stepGenTUS.py +++ b/convlab/policy/genTUS/stepGenTUS.py @@ -26,7 +26,7 @@ class UserActionPolicy(Policy): print("change mode to semantic because only_action=True") self.mode = "semantic" self.max_in_len = 500 - self.max_out_len = 50 if only_action else 200 + self.max_out_len = 100 if only_action else 200 max_act_len = kwargs.get("max_act_len", 2) print("max_act_len", max_act_len) self.max_action_len = max_act_len diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json index 5bf65c9f97f951a367a0abf461e9aa9172d64021..0e8774e20898c2855f368127f8e14e193ac2c21d 100644 --- a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json +++ b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json @@ -1,9 +1,9 @@ { "model": { - "load_path": "convlab/policy/ppo/pretrained_models/supervised", + "load_path": "convlab/policy/ppo/pretrained_models/mle", "pretrained_load_path": "", "use_pretrained_initialisation": false, - "batchsz": 1000, + "batchsz": 500, "seed": 0, "epoch": 50, "eval_frequency": 5, @@ -41,4 +41,4 @@ } }, "usr_nlg": {} -} \ No newline at end of file +} diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/semantic_level_config.json index a4a24598d7eac283bcc05d21b0f695cad90b6433..7525899ccc285871fbe5685c8b731adcb318eb1a 100644 --- a/convlab/policy/ppo/semantic_level_config.json +++ b/convlab/policy/ppo/semantic_level_config.json @@ -1,15 +1,15 @@ { "model": { - "load_path": "", + "load_path": "convlab/policy/ppo/pretrained_models/mle", "use_pretrained_initialisation": false, "pretrained_load_path": "", - "batchsz": 1000, + "batchsz": 500, "seed": 0, "epoch": 50, "eval_frequency": 5, "process_num": 4, "sys_semantic_to_usr": false, - "num_eval_dialogues": 500 + "num_eval_dialogues": 200 }, "vectorizer_sys": { "uncertainty_vector_mul": { @@ -40,4 +40,4 @@ } }, "usr_nlg": {} -} \ No newline at end of file +} diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/tus_semantic_level_config.json index cc0216122dd473ab270fe52617621732d5b2034c..9d56646cc857de85b47a1f9925a0e4bf89d8b524 100644 --- a/convlab/policy/ppo/tus_semantic_level_config.json +++ b/convlab/policy/ppo/tus_semantic_level_config.json @@ -1,6 +1,6 @@ { "model": { - "load_path": "convlab/policy/mle/experiments/experiment_2022-05-23-14-08-43/save/supervised", + "load_path": "convlab/policy/ppo/pretrained_models/mle", "use_pretrained_initialisation": false, "pretrained_load_path": "", "batchsz": 1000, @@ -35,9 +35,9 @@ "TUSPolicy": { "class_path": "convlab.policy.tus.unify.TUS.UserPolicy", "ini_params": { - "config": "convlab/policy/tus/unify/exp/all.json" + "config": "convlab/policy/tus/unify/exp/multiwoz.json" } } }, "usr_nlg": {} -} \ No newline at end of file +} diff --git a/convlab/policy/tus/unify/TUS.py b/convlab/policy/tus/unify/TUS.py index c380df3914d5d636c95b0e76081b361c26d79ff3..f4692e4819079691725d3952c652836fce951fd7 100644 --- a/convlab/policy/tus/unify/TUS.py +++ b/convlab/policy/tus/unify/TUS.py @@ -134,7 +134,7 @@ class UserActionPolicy(Policy): goal = Goal(goal_list) else: goal = ABUS_Goal(self.goal_gen) - self.raw_gaol = goal.domain_goals + self.raw_goal = goal.domain_goals goal_list = old_goal2list(goal.domain_goals) goal = Goal(goal_list) @@ -411,7 +411,8 @@ class UserPolicy(Policy): self.config = json.load(open(config)) else: self.config = config - self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}/multiwoz' + self.config["model_dir"] = f'{self.config["model_dir"]}_{dial_ids_order}' + print("model_dir", self.config['model_dir']) if not os.path.exists(self.config["model_dir"]): # os.mkdir(self.config["model_dir"]) model_downloader(os.path.dirname(self.config["model_dir"]),