From 182d847b00534a7ed32f35385e0d8707183ee1d3 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com>
Date: Wed, 25 Jan 2023 10:51:33 +0100
Subject: [PATCH] Policy config refactoring (#125)

* Separate test and train domains
* Add progress bars in ontology embedder
* Update custom_util.py
* Fix custom_util things I broke
* Github master
* Save dialogue ids in prediction file
* Fix bug in ontology extraction
* Return dialogue ids in predictions file and fix bugs
* Add setsumbt starting config loader
* Add script to extract golden labels from dataset to match model predictions
* Add more setsumbt configs
* Add option to use local files only in transformers package
* Update starting configurations for setsumbt
* Github master
* Update README.md
* Update README.md
* Update convlab/dialog_agent/agent.py
* Revert custom_util.py
* Update custom_util.py
* Commit unverified changes :(:(:(:(
* Fix SetSUMBT bug resulting from new torch feature
* Setsumbt bug fixes
* Policy config refactor
* Policy config refactor
* small bug fix in memory with new config path

Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: Michael Heck <michael.heck@hhu.de>
Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de>
---
 convlab/policy/ppo/README.md | 4 ++--
 .../GenTUS-Semantic-RuleDST.json} | 0
 .../RuleUser-Semantic-RuleDST.json} | 0
 .../RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json} | 0
 .../RuleUser-TemplateNLG-SetSUMBT.json} | 0
 .../RuleUser-TemplateNLG-TripPy.json} | 0
 .../TUS-Semantic-RuleDST.json} | 0
 .../policy/ppo/{config.json => configs/ppo_config.json} | 0
 convlab/policy/ppo/ppo.py | 2 +-
 convlab/policy/ppo/train.py | 7 ++++---
 convlab/policy/vtrace_DPT/README.md | 4 ++--
 .../RuleUser-Semantic-RuleDST.json} | 0
 .../{config.json => configs/multiwoz21_dpt.json} | 0
 convlab/policy/vtrace_DPT/memory.py | 4 +++-
 convlab/policy/vtrace_DPT/train.py | 7 ++++---
 convlab/policy/vtrace_DPT/vtrace.py | 2 +-
 16 files changed, 17 insertions(+), 13 deletions(-)
 rename convlab/policy/ppo/{semanticGenTUS-RuleDST-PPOPolicy.json => configs/GenTUS-Semantic-RuleDST.json} (100%)
 rename convlab/policy/ppo/{semantic_level_config.json => configs/RuleUser-Semantic-RuleDST.json} (100%)
 rename convlab/policy/ppo/{setsumbt_unc_config.json => configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json} (100%)
 rename convlab/policy/ppo/{setsumbt_config.json => configs/RuleUser-TemplateNLG-SetSUMBT.json} (100%)
 rename convlab/policy/ppo/{trippy_config.json => configs/RuleUser-TemplateNLG-TripPy.json} (100%)
 rename convlab/policy/ppo/{tus_semantic_level_config.json => configs/TUS-Semantic-RuleDST.json} (100%)
 rename convlab/policy/ppo/{config.json => configs/ppo_config.json} (100%)
 rename convlab/policy/vtrace_DPT/{semantic_level_config.json => configs/RuleUser-Semantic-RuleDST.json} (100%)
 rename convlab/policy/vtrace_DPT/{config.json => configs/multiwoz21_dpt.json} (100%)

diff --git a/convlab/policy/ppo/README.md b/convlab/policy/ppo/README.md
index 9efb621b..0e2f2906 100755
--- a/convlab/policy/ppo/README.md
+++ b/convlab/policy/ppo/README.md
@@ -11,10 +11,10 @@ If you want to obtain a supervised model for pre-training, please have a look in
 Starting a RL training is as easy as executing
 
 ```sh
-$ python train.py --path=your_environment_config --seed=SEED
+$ python train.py --config_name=your_config_name --seed=SEED
 ```
 
-One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance
+One example for the environment-config is **RuleUser-Semantic-RuleDST**, where parameters for the training are specified, for instance
 
 - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
 - process_num: the number of processes to use during evaluation to speed it up
diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
rename to convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json
diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/configs/RuleUser-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/ppo/semantic_level_config.json
rename to convlab/policy/ppo/configs/RuleUser-Semantic-RuleDST.json
diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
similarity index 100%
rename from convlab/policy/ppo/setsumbt_unc_config.json
rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
diff --git a/convlab/policy/ppo/setsumbt_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
similarity index 100%
rename from convlab/policy/ppo/setsumbt_config.json
rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-TripPy.json
similarity index 100%
rename from convlab/policy/ppo/trippy_config.json
rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-TripPy.json
diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/configs/TUS-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/ppo/tus_semantic_level_config.json
rename to convlab/policy/ppo/configs/TUS-Semantic-RuleDST.json
diff --git a/convlab/policy/ppo/config.json b/convlab/policy/ppo/configs/ppo_config.json
similarity index 100%
rename from convlab/policy/ppo/config.json
rename to convlab/policy/ppo/configs/ppo_config.json
diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py
index 7ceefcd7..28fee71c 100755
--- a/convlab/policy/ppo/ppo.py
+++ b/convlab/policy/ppo/ppo.py
@@ -22,7 +22,7 @@ class PPO(Policy):
 
     def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None):
 
-        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f:
             cfg = json.load(f)
         self.save_dir = os.path.join(os.path.dirname(
             os.path.abspath(__file__)), cfg['save_dir'])
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 42fd9425..a2814df3 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -182,8 +182,8 @@ if __name__ == '__main__':
     begin_time = datetime.now()
 
     parser = ArgumentParser()
-    parser.add_argument("--path", type=str, default='convlab/policy/ppo/semantic_level_config.json',
-                        help="Load path for config file")
+    parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST',
+                        help="Name of the configuration")
     parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
@@ -191,7 +191,8 @@ if __name__ == '__main__':
     parser.add_argument("--save_eval_dials", type=bool, default=False,
                         help="Flag for saving dialogue_info during evaluation")
 
-    path = parser.parse_args().path
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs',
+                        f'{parser.parse_args().config_name}.json')
     seed = parser.parse_args().seed
     mode = parser.parse_args().mode
     save_eval = parser.parse_args().save_eval_dials
diff --git a/convlab/policy/vtrace_DPT/README.md b/convlab/policy/vtrace_DPT/README.md
index 3b9e6f42..59798f57 100644
--- a/convlab/policy/vtrace_DPT/README.md
+++ b/convlab/policy/vtrace_DPT/README.md
@@ -31,10 +31,10 @@ We provide several supervised trained models on hugging-face to reproduce the re
 Starting a RL training is as easy as executing
 
 ```sh
-$ python train.py --path=your_environment_config --seed=SEED
+$ python train.py --config_name=your_config_name --seed=SEED
 ```
 
-One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance
+One example for the environment-config is **RuleUser-Semantic-RuleDST**, where parameters for the training are specified, for instance
 
 - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
 - process_num: the number of processes to use during evaluation to speed it up
diff --git a/convlab/policy/vtrace_DPT/semantic_level_config.json b/convlab/policy/vtrace_DPT/configs/RuleUser-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/vtrace_DPT/semantic_level_config.json
rename to convlab/policy/vtrace_DPT/configs/RuleUser-Semantic-RuleDST.json
diff --git a/convlab/policy/vtrace_DPT/config.json b/convlab/policy/vtrace_DPT/configs/multiwoz21_dpt.json
similarity index 100%
rename from convlab/policy/vtrace_DPT/config.json
rename to convlab/policy/vtrace_DPT/configs/multiwoz21_dpt.json
diff --git a/convlab/policy/vtrace_DPT/memory.py b/convlab/policy/vtrace_DPT/memory.py
index e9e13eb6..7b93d8a4 100644
--- a/convlab/policy/vtrace_DPT/memory.py
+++ b/convlab/policy/vtrace_DPT/memory.py
@@ -17,7 +17,9 @@ class Memory:
 
     def __init__(self, seed=0):
 
-        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+        dir_name = os.path.dirname(os.path.abspath(__file__))
+        self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json')
+        with open(self.config_path, 'r') as f:
             cfg = json.load(f)
 
         self.batch_size = cfg.get('batchsz', 32)
diff --git a/convlab/policy/vtrace_DPT/train.py b/convlab/policy/vtrace_DPT/train.py
index f441da29..738fe81b 100644
--- a/convlab/policy/vtrace_DPT/train.py
+++ b/convlab/policy/vtrace_DPT/train.py
@@ -101,8 +101,8 @@ if __name__ == '__main__':
     begin_time = datetime.now()
 
     parser = ArgumentParser()
-    parser.add_argument("--path", type=str, default='convlab/policy/vtrace_DPT/semantic_level_config.json',
-                        help="Load path for config file")
+    parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST',
+                        help="Name of the configuration")
    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
@@ -110,7 +110,8 @@ if __name__ == '__main__':
     parser.add_argument("--save_eval_dials", type=bool, default=False,
                         help="Flag for saving dialogue_info during evaluation")
 
-    path = parser.parse_args().path
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs',
+                        f"{parser.parse_args().config_name}.json")
     seed = parser.parse_args().seed
     mode = parser.parse_args().mode
     save_eval = parser.parse_args().save_eval_dials
diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py
index 625c214f..2918f4df 100644
--- a/convlab/policy/vtrace_DPT/vtrace.py
+++ b/convlab/policy/vtrace_DPT/vtrace.py
@@ -26,7 +26,7 @@ class VTRACE(nn.Module, Policy):
 
         super(VTRACE, self).__init__()
 
         dir_name = os.path.dirname(os.path.abspath(__file__))
-        self.config_path = os.path.join(dir_name, 'config.json')
+        self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json')
         with open(self.config_path, 'r') as f:
             cfg = json.load(f)
-- 
GitLab
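
For anyone updating launch scripts, the change to both `train.py` entry points in this patch boils down to the same pattern: the `--path` argument is replaced by `--config_name`, and the script resolves that name against its own `configs/` directory. The sketch below mirrors that logic for illustration only; `resolve_config_path` is a hypothetical helper name, not a function in ConvLab, and the real code lives inline in `convlab/policy/ppo/train.py` and `convlab/policy/vtrace_DPT/train.py`.

```python
import os


def resolve_config_path(config_name: str, policy_dir: str) -> str:
    """Map a --config_name value to <policy_dir>/configs/<config_name>.json."""
    return os.path.join(policy_dir, "configs", f"{config_name}.json")


if __name__ == "__main__":
    # Equivalent CLI call after this patch:
    #   python convlab/policy/ppo/train.py --config_name=RuleUser-Semantic-RuleDST --seed=0
    ppo_dir = os.path.join("convlab", "policy", "ppo")
    print(resolve_config_path("RuleUser-Semantic-RuleDST", ppo_dir))
    # -> convlab/policy/ppo/configs/RuleUser-Semantic-RuleDST.json
```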
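
The file renames also affect anything that referenced the old config file names directly. As a quick migration reference, the old-to-new mapping below is derived purely from the rename entries in the patch; the dictionaries are illustrative and not code shipped with the commit.

```python
# Old config file names and their new locations, relative to each policy folder.
PPO_CONFIG_RENAMES = {
    "semanticGenTUS-RuleDST-PPOPolicy.json": "configs/GenTUS-Semantic-RuleDST.json",
    "semantic_level_config.json": "configs/RuleUser-Semantic-RuleDST.json",
    "setsumbt_unc_config.json": "configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json",
    "setsumbt_config.json": "configs/RuleUser-TemplateNLG-SetSUMBT.json",
    "trippy_config.json": "configs/RuleUser-TemplateNLG-TripPy.json",
    "tus_semantic_level_config.json": "configs/TUS-Semantic-RuleDST.json",
    "config.json": "configs/ppo_config.json",
}

VTRACE_DPT_CONFIG_RENAMES = {
    "semantic_level_config.json": "configs/RuleUser-Semantic-RuleDST.json",
    "config.json": "configs/multiwoz21_dpt.json",
}
```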