From b88d547fd3c4f271a2d80476cdeedcb4ef7b0bf2 Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Thu, 17 Mar 2022 14:52:56 +0100
Subject: [PATCH] small bug fix in gdpl

---
 convlab2/policy/gdpl/estimator.py | 2 +-
 convlab2/policy/gdpl/train.py     | 3 ++-
 convlab2/policy/pg/pg.py          | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/convlab2/policy/gdpl/estimator.py b/convlab2/policy/gdpl/estimator.py
index 985ca33d..3012465d 100755
--- a/convlab2/policy/gdpl/estimator.py
+++ b/convlab2/policy/gdpl/estimator.py
@@ -40,7 +40,7 @@ class RewardEstimator(object):
             self.irl_iter = iter(self.data_train)
         if pretrain:
             self.data_train = manager.create_dataset_irl('train', cfg['batchsz'])
-            self.data_valid = manager.create_dataset_irl('valid', cfg['batchsz'])
+            self.data_valid = manager.create_dataset_irl('validation', cfg['batchsz'])
             self.data_test = manager.create_dataset_irl('test', cfg['batchsz'])
             self.irl_iter = iter(self.data_train)
             self.irl_iter_valid = iter(self.data_valid)
diff --git a/convlab2/policy/gdpl/train.py b/convlab2/policy/gdpl/train.py
index 1603618a..bb3d1e1e 100755
--- a/convlab2/policy/gdpl/train.py
+++ b/convlab2/policy/gdpl/train.py
@@ -185,6 +185,7 @@ if __name__ == '__main__':
                         help="Load path for config file")
     parser.add_argument("--seed", type=int, default=0,
                         help="Seed for the policy parameter initialization")
+    parser.add_argument("--pretrain", action='store_true', help="whether to pretrain the reward estimator")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
     parser.add_argument("--save_eval_dials", type=bool, default=False,
@@ -209,7 +210,7 @@ if __name__ == '__main__':
         set_seed(seed)
 
     policy_sys = GDPL(True, seed=conf['model']['seed'], vectorizer=conf['vectorizer_sys_activated'])
-    rewarder = RewardEstimator(policy_sys.vector, False)
+    rewarder = RewardEstimator(policy_sys.vector, parser.parse_args().pretrain)
 
     # Load model
     if conf['model']['use_pretrained_initialisation']:
diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py
index c2740d13..8b4088c8 100755
--- a/convlab2/policy/pg/pg.py
+++ b/convlab2/policy/pg/pg.py
@@ -153,7 +153,7 @@ class PG(Policy):
         for p in self.policy.parameters():
             p.grad[p.grad != p.grad] = 0.0
         # gradient clipping, for stability
-        torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
+        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 10)
         # self.lock.acquire() # retain lock to update weights
         self.policy_optim.step()
         # self.lock.release() # release lock
--
GitLab