diff --git a/convlab/policy/ppo/README.md b/convlab/policy/ppo/README.md index 9efb621bcaa49cdcaba0446b90c5c80f77f77876..0e2f2906fad81113d0e420733be743aa8c18de70 100755 --- a/convlab/policy/ppo/README.md +++ b/convlab/policy/ppo/README.md @@ -11,10 +11,10 @@ If you want to obtain a supervised model for pre-training, please have a look in Starting a RL training is as easy as executing ```sh -$ python train.py --path=your_environment_config --seed=SEED +$ python train.py --config_name=your_config_name --seed=SEED ``` -One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance +One example for the environment-config is **RuleUser-Semantic-RuleDST**, where parameters for the training are specified, for instance - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl - process_num: the number of processes to use during evaluation to speed it up diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json similarity index 100% rename from convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json rename to convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/configs/RuleUser-Semantic-RuleDST.json similarity index 100% rename from convlab/policy/ppo/semantic_level_config.json rename to convlab/policy/ppo/configs/RuleUser-Semantic-RuleDST.json diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json similarity index 100% rename from convlab/policy/ppo/setsumbt_unc_config.json rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json diff --git a/convlab/policy/ppo/setsumbt_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json similarity index 100% rename from 
convlab/policy/ppo/setsumbt_config.json rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-TripPy.json similarity index 100% rename from convlab/policy/ppo/trippy_config.json rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-TripPy.json diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/configs/TUS-Semantic-RuleDST.json similarity index 100% rename from convlab/policy/ppo/tus_semantic_level_config.json rename to convlab/policy/ppo/configs/TUS-Semantic-RuleDST.json diff --git a/convlab/policy/ppo/config.json b/convlab/policy/ppo/configs/ppo_config.json similarity index 100% rename from convlab/policy/ppo/config.json rename to convlab/policy/ppo/configs/ppo_config.json diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py index 7ceefcd77a4f07a1f42c2d4a368715969d3cdf53..28fee71c70c640319b29becc77bddfe8311f2767 100755 --- a/convlab/policy/ppo/ppo.py +++ b/convlab/policy/ppo/ppo.py @@ -22,7 +22,7 @@ class PPO(Policy): def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): - with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: + with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs', 'ppo_config.json'), 'r') as f: cfg = json.load(f) self.save_dir = os.path.join(os.path.dirname( os.path.abspath(__file__)), cfg['save_dir']) diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py index 42fd9425d1416555da52073f15635cf7cfd99997..a2814df36a24642d053897f2e42e16cf56a3adea 100755 --- a/convlab/policy/ppo/train.py +++ b/convlab/policy/ppo/train.py @@ -182,8 +182,8 @@ if __name__ == '__main__': begin_time = datetime.now() parser = ArgumentParser() - parser.add_argument("--path", type=str, default='convlab/policy/ppo/semantic_level_config.json', - help="Load path for config file") + 
parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST', + help="Name of the configuration") parser.add_argument("--seed", type=int, default=None, help="Seed for the policy parameter initialization") parser.add_argument("--mode", type=str, default='info', @@ -191,7 +191,8 @@ if __name__ == '__main__': parser.add_argument("--save_eval_dials", type=bool, default=False, help="Flag for saving dialogue_info during evaluation") - path = parser.parse_args().path + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs', + f'{parser.parse_args().config_name}.json') seed = parser.parse_args().seed mode = parser.parse_args().mode save_eval = parser.parse_args().save_eval_dials diff --git a/convlab/policy/vtrace_DPT/README.md b/convlab/policy/vtrace_DPT/README.md index 3b9e6f4216f165aa99862d4beca554eee5430669..59798f578c0d62a7a2954d1977c86c90ce612701 100644 --- a/convlab/policy/vtrace_DPT/README.md +++ b/convlab/policy/vtrace_DPT/README.md @@ -31,10 +31,10 @@ We provide several supervised trained models on hugging-face to reproduce the re Starting a RL training is as easy as executing ```sh -$ python train.py --path=your_environment_config --seed=SEED +$ python train.py --config_name=your_config_name --seed=SEED ``` -One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance +One example for the environment-config is **RuleUser-Semantic-RuleDST**, where parameters for the training are specified, for instance - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl - process_num: the number of processes to use during evaluation to speed it up diff --git a/convlab/policy/vtrace_DPT/semantic_level_config.json b/convlab/policy/vtrace_DPT/configs/RuleUser-Semantic-RuleDST.json similarity index 100% rename from convlab/policy/vtrace_DPT/semantic_level_config.json rename to 
convlab/policy/vtrace_DPT/configs/RuleUser-Semantic-RuleDST.json diff --git a/convlab/policy/vtrace_DPT/config.json b/convlab/policy/vtrace_DPT/configs/multiwoz21_dpt.json similarity index 100% rename from convlab/policy/vtrace_DPT/config.json rename to convlab/policy/vtrace_DPT/configs/multiwoz21_dpt.json diff --git a/convlab/policy/vtrace_DPT/memory.py b/convlab/policy/vtrace_DPT/memory.py index e9e13eb69bf68fa309ad552d62ff752c90ff230f..7b93d8a4860f321e33f1b94d9a9e935456fcdef5 100644 --- a/convlab/policy/vtrace_DPT/memory.py +++ b/convlab/policy/vtrace_DPT/memory.py @@ -17,7 +17,9 @@ class Memory: def __init__(self, seed=0): - with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: + dir_name = os.path.dirname(os.path.abspath(__file__)) + self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json') + with open(self.config_path, 'r') as f: cfg = json.load(f) self.batch_size = cfg.get('batchsz', 32) diff --git a/convlab/policy/vtrace_DPT/train.py b/convlab/policy/vtrace_DPT/train.py index f441da29cf44fb366daf9f5afd3f2f21d1bd420d..738fe81bcf97b7232e0267036fe53c262df63982 100644 --- a/convlab/policy/vtrace_DPT/train.py +++ b/convlab/policy/vtrace_DPT/train.py @@ -101,8 +101,8 @@ if __name__ == '__main__': begin_time = datetime.now() parser = ArgumentParser() - parser.add_argument("--path", type=str, default='convlab/policy/vtrace_DPT/semantic_level_config.json', - help="Load path for config file") + parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST', + help="Name of the configuration") parser.add_argument("--seed", type=int, default=None, help="Seed for the policy parameter initialization") parser.add_argument("--mode", type=str, default='info', @@ -110,7 +110,8 @@ if __name__ == '__main__': parser.add_argument("--save_eval_dials", type=bool, default=False, help="Flag for saving dialogue_info during evaluation") - path = parser.parse_args().path + path = 
os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs', + f"{parser.parse_args().config_name}.json") seed = parser.parse_args().seed mode = parser.parse_args().mode save_eval = parser.parse_args().save_eval_dials diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py index 625c214f5a02bc010aaa8118fa9d78619e718b42..2918f4dfb019cef3cfe9e0d3981b70c9502700c8 100644 --- a/convlab/policy/vtrace_DPT/vtrace.py +++ b/convlab/policy/vtrace_DPT/vtrace.py @@ -26,7 +26,7 @@ class VTRACE(nn.Module, Policy): super(VTRACE, self).__init__() dir_name = os.path.dirname(os.path.abspath(__file__)) - self.config_path = os.path.join(dir_name, 'config.json') + self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json') with open(self.config_path, 'r') as f: cfg = json.load(f)