Skip to content
Snippets Groups Projects
Unverified Commit 182d847b authored by Carel van Niekerk's avatar Carel van Niekerk Committed by GitHub
Browse files

Policy config refactoring (#125)


* Seperate test and train domains

* Add progress bars in ontology embedder

* Update custom_util.py

* Fix custom_util things I broke

* Github master

* Save dialogue ids in prediction file

* Fix bug in ontology enxtraction

* Return dialogue ids in predictions file and fix bugs

* Add setsumbt starting config loader

* Add script to extract golden labels from dataset to match model predictions

* Add more setsumbt configs

* Add option to use local files only in transformers package

* Update starting configurations for setsumbt

* Github master

* Update README.md

* Update README.md

* Update convlab/dialog_agent/agent.py

* Revert custom_util.py

* Update custom_util.py

* Commit unverified chnages :(:(:(:(

* Fix SetSUMBT bug resulting from new torch feature

* Setsumbt bug fixes

* Policy config refactor

* Policy config refactor

* small bug fix in memory with new config path

Co-authored-by: default avatarCarel van Niekerk <carel.niekerk@hhu.de>
Co-authored-by: default avatarMichael Heck <michael.heck@hhu.de>
Co-authored-by: default avatarChristian Geishauser <christian.geishauser@hhu.de>
parent 44746e90
Branches
No related tags found
No related merge requests found
Showing
with 17 additions and 13 deletions
...@@ -11,10 +11,10 @@ If you want to obtain a supervised model for pre-training, please have a look in ...@@ -11,10 +11,10 @@ If you want to obtain a supervised model for pre-training, please have a look in
Starting a RL training is as easy as executing Starting a RL training is as easy as executing
```sh ```sh
$ python train.py --path=your_environment_config --seed=SEED $ python train.py --config_name=your_config_name --seed=SEED
``` ```
One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance One example for the environment-config is **RuleUser-Semantic-RuleDST**, where parameters for the training are specified, for instance
- load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
- process_num: the number of processes to use during evaluation to speed it up - process_num: the number of processes to use during evaluation to speed it up
......
...@@ -22,7 +22,7 @@ class PPO(Policy): ...@@ -22,7 +22,7 @@ class PPO(Policy):
def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f:
cfg = json.load(f) cfg = json.load(f)
self.save_dir = os.path.join(os.path.dirname( self.save_dir = os.path.join(os.path.dirname(
os.path.abspath(__file__)), cfg['save_dir']) os.path.abspath(__file__)), cfg['save_dir'])
......
...@@ -182,8 +182,8 @@ if __name__ == '__main__': ...@@ -182,8 +182,8 @@ if __name__ == '__main__':
begin_time = datetime.now() begin_time = datetime.now()
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--path", type=str, default='convlab/policy/ppo/semantic_level_config.json', parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST',
help="Load path for config file") help="Name of the configuration")
parser.add_argument("--seed", type=int, default=None, parser.add_argument("--seed", type=int, default=None,
help="Seed for the policy parameter initialization") help="Seed for the policy parameter initialization")
parser.add_argument("--mode", type=str, default='info', parser.add_argument("--mode", type=str, default='info',
...@@ -191,7 +191,8 @@ if __name__ == '__main__': ...@@ -191,7 +191,8 @@ if __name__ == '__main__':
parser.add_argument("--save_eval_dials", type=bool, default=False, parser.add_argument("--save_eval_dials", type=bool, default=False,
help="Flag for saving dialogue_info during evaluation") help="Flag for saving dialogue_info during evaluation")
path = parser.parse_args().path path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs',
f'{parser.parse_args().config_name}.json')
seed = parser.parse_args().seed seed = parser.parse_args().seed
mode = parser.parse_args().mode mode = parser.parse_args().mode
save_eval = parser.parse_args().save_eval_dials save_eval = parser.parse_args().save_eval_dials
......
...@@ -31,10 +31,10 @@ We provide several supervised trained models on hugging-face to reproduce the re ...@@ -31,10 +31,10 @@ We provide several supervised trained models on hugging-face to reproduce the re
Starting a RL training is as easy as executing Starting a RL training is as easy as executing
```sh ```sh
$ python train.py --path=your_environment_config --seed=SEED $ python train.py --config_name=your_config_name --seed=SEED
``` ```
One example for the environment-config is **semantic_level_config.json**, where parameters for the training are specified, for instance One example for the environment-config is **RuleUser-Semantic-RuleDST**, where parameters for the training are specified, for instance
- load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl - load_path: provide a path to initialise the model with a pre-trained model, skip the ending .pol.mdl
- process_num: the number of processes to use during evaluation to speed it up - process_num: the number of processes to use during evaluation to speed it up
......
...@@ -17,7 +17,9 @@ class Memory: ...@@ -17,7 +17,9 @@ class Memory:
def __init__(self, seed=0): def __init__(self, seed=0):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: dir_name = os.path.dirname(os.path.abspath(__file__))
self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json')
with open(self.config_path, 'r') as f:
cfg = json.load(f) cfg = json.load(f)
self.batch_size = cfg.get('batchsz', 32) self.batch_size = cfg.get('batchsz', 32)
......
...@@ -101,8 +101,8 @@ if __name__ == '__main__': ...@@ -101,8 +101,8 @@ if __name__ == '__main__':
begin_time = datetime.now() begin_time = datetime.now()
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--path", type=str, default='convlab/policy/vtrace_DPT/semantic_level_config.json', parser.add_argument("--config_name", type=str, default='RuleUser-Semantic-RuleDST',
help="Load path for config file") help="Name of the configuration")
parser.add_argument("--seed", type=int, default=None, parser.add_argument("--seed", type=int, default=None,
help="Seed for the policy parameter initialization") help="Seed for the policy parameter initialization")
parser.add_argument("--mode", type=str, default='info', parser.add_argument("--mode", type=str, default='info',
...@@ -110,7 +110,8 @@ if __name__ == '__main__': ...@@ -110,7 +110,8 @@ if __name__ == '__main__':
parser.add_argument("--save_eval_dials", type=bool, default=False, parser.add_argument("--save_eval_dials", type=bool, default=False,
help="Flag for saving dialogue_info during evaluation") help="Flag for saving dialogue_info during evaluation")
path = parser.parse_args().path path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs',
f"{parser.parse_args().config_name}.json")
seed = parser.parse_args().seed seed = parser.parse_args().seed
mode = parser.parse_args().mode mode = parser.parse_args().mode
save_eval = parser.parse_args().save_eval_dials save_eval = parser.parse_args().save_eval_dials
......
...@@ -26,7 +26,7 @@ class VTRACE(nn.Module, Policy): ...@@ -26,7 +26,7 @@ class VTRACE(nn.Module, Policy):
super(VTRACE, self).__init__() super(VTRACE, self).__init__()
dir_name = os.path.dirname(os.path.abspath(__file__)) dir_name = os.path.dirname(os.path.abspath(__file__))
self.config_path = os.path.join(dir_name, 'config.json') self.config_path = os.path.join(dir_name, 'configs', 'multiwoz21_dpt.json')
with open(self.config_path, 'r') as f: with open(self.config_path, 'r') as f:
cfg = json.load(f) cfg = json.load(f)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment