From 182d847b00534a7ed32f35385e0d8707183ee1d3 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <40663106+carelvniekerk@users.noreply.github.com>
Date: Wed, 25 Jan 2023 10:51:33 +0100
Subject: [PATCH] Policy config refactoring (#125)

* Separate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* GitHub master

* Save dialogue ids in prediction file

* Fix bug in ontology extraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* GitHub master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified changes :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* SetSUMBT bug fixes

* Policy config refactor

* Policy config refactor

* Small bug fix in memory with new config path

Co-authored-by: Carel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: Michael Heck <michael.heck@hhu.de>
Co-authored-by: Christian Geishauser <christian.geishauser@hhu.de>
---
 convlab/policy/ppo/README.md                               | 4 ++--
 .../GenTUS-Semantic-RuleDST.json}                          | 0
 .../RuleUser-Semantic-RuleDST.json}                        | 0
 .../RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json}  | 0
 .../RuleUser-TemplateNLG-SetSUMBT.json}                    | 0
 .../RuleUser-TemplateNLG-TripPy.json}                      | 0
 .../TUS-Semantic-RuleDST.json}                             | 0
 .../policy/ppo/{config.json => configs/ppo_config.json}    | 0
 convlab/policy/ppo/ppo.py                                  | 2 +-
 convlab/policy/ppo/train.py                                | 7 ++++---
 convlab/policy/vtrace_DPT/README.md                        | 4 ++--
 .../RuleUser-Semantic-RuleDST.json}                        | 0
 .../{config.json => configs/multiwoz21_dpt.json}           | 0
 convlab/policy/vtrace_DPT/memory.py                        | 4 +++-
 convlab/policy/vtrace_DPT/train.py                         | 7 ++++---
 convlab/policy/vtrace_DPT/vtrace.py                        | 2 +-
 16 files changed, 17 insertions(+), 13 deletions(-)
 rename convlab/policy/ppo/{semanticGenTUS-RuleDST-PPOPolicy.json => configs/GenTUS-Semantic-RuleDST.json} (100%)
 rename convlab/policy/ppo/{semantic_level_config.json => configs/RuleUser-Semantic-RuleDST.json} (100%)
 rename convlab/policy/ppo/{setsumbt_unc_config.json => configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json} (100%)
 rename convlab/policy/ppo/{setsumbt_config.json => configs/RuleUser-TemplateNLG-SetSUMBT.json} (100%)
 rename convlab/policy/ppo/{trippy_config.json => configs/RuleUser-TemplateNLG-TripPy.json} (100%)
 rename convlab/policy/ppo/{tus_semantic_level_config.json => configs/TUS-Semantic-RuleDST.json} (100%)
 rename convlab/policy/ppo/{config.json => configs/ppo_config.json} (100%)
 rename convlab/policy/vtrace_DPT/{semantic_level_config.json => configs/RuleUser-Semantic-RuleDST.json} (100%)
 rename convlab/policy/vtrace_DPT/{config.json => configs/multiwoz21_dpt.json} (100%)

diff --git a/convlab/policy/ppo/README.md b/convlab/policy/ppo/README.md
index 9efb621b..0e2f2906 100755
--- a/convlab/policy/ppo/README.md
+++ b/convlab/policy/ppo/README.md
@@ -11,10 +11,10 @@ If you want to obtain a supervised model for pre-training, please have a look in
 Starting a RL training is as easy as executing
 
 ```sh
-$ python train.py --path=your_environment_config --seed=SEED
+$ python train.py --config_name=your_config_name --seed=SEED
 ```
 
-One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance
+One example of an environment config is **RuleUser-Semantic-RuleDST**, which specifies the training parameters, for instance
 
 - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
 - process_num: the number of processes to use during evaluation to speed it up
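To make the new `--config_name` workflow concrete, here is a minimal sketch (in Python, matching the rest of the patch) that writes such an environment config into the new `configs/` directory. Only `load_path` and `process_num` appear in the README bullets above; the file name `MyExperiment` and the concrete field values are assumptions.

```python
import json
import os

# Hedged sketch: create a minimal environment config in the new configs/
# folder so it can be selected via --config_name. Only load_path and
# process_num are documented in the README above; the name "MyExperiment"
# and the concrete values are hypothetical.
config = {
    "load_path": "",   # path to a pre-trained model, without the .pol.mdl ending
    "process_num": 4,  # number of parallel evaluation processes
}

configs_dir = os.path.join("convlab", "policy", "ppo", "configs")
os.makedirs(configs_dir, exist_ok=True)
with open(os.path.join(configs_dir, "MyExperiment.json"), "w") as f:
    json.dump(config, f, indent=4)

# The training can then be started with:
#   python train.py --config_name=MyExperiment --seed=0
```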
diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
rename to convlab/policy/ppo/configs/GenTUS-Semantic-RuleDST.json
diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/configs/RuleUser-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/ppo/semantic_level_config.json
rename to convlab/policy/ppo/configs/RuleUser-Semantic-RuleDST.json
diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
similarity index 100%
rename from convlab/policy/ppo/setsumbt_unc_config.json
rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT-VectorUncertainty.json
diff --git a/convlab/policy/ppo/setsumbt_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
similarity index 100%
rename from convlab/policy/ppo/setsumbt_config.json
rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-SetSUMBT.json
diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/configs/RuleUser-TemplateNLG-TripPy.json
similarity index 100%
rename from convlab/policy/ppo/trippy_config.json
rename to convlab/policy/ppo/configs/RuleUser-TemplateNLG-TripPy.json
diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/configs/TUS-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/ppo/tus_semantic_level_config.json
rename to convlab/policy/ppo/configs/TUS-Semantic-RuleDST.json
diff --git a/convlab/policy/ppo/config.json b/convlab/policy/ppo/configs/ppo_config.json
similarity index 100%
rename from convlab/policy/ppo/config.json
rename to convlab/policy/ppo/configs/ppo_config.json
diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py
index 7ceefcd7..28fee71c 100755
--- a/convlab/policy/ppo/ppo.py
+++ b/convlab/policy/ppo/ppo.py
@@ -22,7 +22,7 @@ class PPO(Policy):
 
     def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None):
 
-        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs', 'ppo_config.json'), 'r') as f:
             cfg = json.load(f)
         self.save_dir = os.path.join(os.path.dirname(
             os.path.abspath(__file__)), cfg['save_dir'])
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 42fd9425..a2814df3 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -182,8 +182,8 @@ if __name__ == '__main__':
 
     begin_time = datetime.now()
     parser = ArgumentParser()
-    parser.add_argument("--path", type=str, default='convlab/policy/ppo/semantic_level_config.json',
-                        help="Load path for config file")
+    parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST',
+                        help="Name of the configuration")
     parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
@@ -191,7 +191,8 @@ if __name__ == '__main__':
     parser.add_argument("--save_eval_dials", type=bool, default=False,
                         help="Flag for saving dialogue_info during evaluation")
 
-    path = parser.parse_args().path
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs',
+                        f'{parser.parse_args().config_name}.json')
     seed = parser.parse_args().seed
     mode = parser.parse_args().mode
     save_eval = parser.parse_args().save_eval_dials
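The hunk above resolves the config name to a JSON file next to the module. Below is a self-contained sketch of the same resolution, assuming it runs as a script inside the package directory; it parses the arguments once rather than once per attribute (a behaviour-preserving tidy-up of the repeated `parse_args()` calls in the patch).

```python
import os
from argparse import ArgumentParser

# Sketch of the config-name resolution introduced above, with a single
# parse_args() call instead of one per attribute.
parser = ArgumentParser()
parser.add_argument("--config_name", type=str, default="RuleUser-Semantic-RuleDST",
                    help="Name of the config file in the configs folder")
parser.add_argument("--seed", type=int, default=None)
args = parser.parse_args()

# Resolve relative to this file, so the script works from any working directory.
path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                    "configs", f"{args.config_name}.json")
```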
diff --git a/convlab/policy/vtrace_DPT/README.md b/convlab/policy/vtrace_DPT/README.md
index 3b9e6f42..59798f57 100644
--- a/convlab/policy/vtrace_DPT/README.md
+++ b/convlab/policy/vtrace_DPT/README.md
@@ -31,10 +31,10 @@ We provide several supervised trained models on hugging-face to reproduce the re
 Starting a RL training is as easy as executing
 
 ```sh
-$ python train.py --path=your_environment_config --seed=SEED
+$ python train.py --config_name=your_config_name --seed=SEED
 ```
 
-One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance
+One example of an environment config is **RuleUser-Semantic-RuleDST**, which specifies the training parameters, for instance
 
 - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
 - process_num: the number of processes to use during evaluation to speed it up
diff --git a/convlab/policy/vtrace_DPT/semantic_level_config.json b/convlab/policy/vtrace_DPT/configs/RuleUser-Semantic-RuleDST.json
similarity index 100%
rename from convlab/policy/vtrace_DPT/semantic_level_config.json
rename to convlab/policy/vtrace_DPT/configs/RuleUser-Semantic-RuleDST.json
diff --git a/convlab/policy/vtrace_DPT/config.json b/convlab/policy/vtrace_DPT/configs/multiwoz21_dpt.json
similarity index 100%
rename from convlab/policy/vtrace_DPT/config.json
rename to convlab/policy/vtrace_DPT/configs/multiwoz21_dpt.json
diff --git a/convlab/policy/vtrace_DPT/memory.py b/convlab/policy/vtrace_DPT/memory.py
index e9e13eb6..7b93d8a4 100644
--- a/convlab/policy/vtrace_DPT/memory.py
+++ b/convlab/policy/vtrace_DPT/memory.py
@@ -17,7 +17,9 @@ class Memory:
 
     def __init__(self, seed=0):
 
-        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
+        dir_name = os.path.dirname(os.path.abspath(__file__))
+        self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json')
+        with open(self.config_path, 'r') as f:
             cfg = json.load(f)
 
         self.batch_size = cfg.get('batchsz', 32)
diff --git a/convlab/policy/vtrace_DPT/train.py b/convlab/policy/vtrace_DPT/train.py
index f441da29..738fe81b 100644
--- a/convlab/policy/vtrace_DPT/train.py
+++ b/convlab/policy/vtrace_DPT/train.py
@@ -101,8 +101,8 @@ if __name__ == '__main__':
 
     begin_time = datetime.now()
     parser = ArgumentParser()
-    parser.add_argument("--path", type=str, default='convlab/policy/vtrace_DPT/semantic_level_config.json',
-                        help="Load path for config file")
+    parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST',
+                        help="Name of the configuration")
     parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
@@ -110,7 +110,8 @@ if __name__ == '__main__':
     parser.add_argument("--save_eval_dials", type=bool, default=False,
                         help="Flag for saving dialogue_info during evaluation")
 
-    path = parser.parse_args().path
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs',
+                        f"{parser.parse_args().config_name}.json")
     seed = parser.parse_args().seed
     mode = parser.parse_args().mode
     save_eval = parser.parse_args().save_eval_dials
diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py
index 625c214f..2918f4df 100644
--- a/convlab/policy/vtrace_DPT/vtrace.py
+++ b/convlab/policy/vtrace_DPT/vtrace.py
@@ -26,7 +26,7 @@ class VTRACE(nn.Module, Policy):
         super(VTRACE, self).__init__()
 
         dir_name = os.path.dirname(os.path.abspath(__file__))
-        self.config_path = os.path.join(dir_name, 'config.json')
+        self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json')
 
         with open(self.config_path, 'r') as f:
             cfg = json.load(f)
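Note that memory.py and vtrace.py now hard-code the identical `configs/multiwoz21_dpt.json` path independently. A shared helper such as the hypothetical sketch below (not part of this patch) would keep the two call sites from drifting apart:

```python
import os

# Hypothetical helper, not part of the patch: both Memory and VTRACE
# resolve the same default DPT config; a single shared function avoids
# the two hard-coded paths going out of sync.
def default_dpt_config_path():
    dir_name = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(dir_name, "configs", "multiwoz21_dpt.json")
```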
-- 
GitLab