Commit b88d547f authored by Christian

small bug fix in gdpl

parent 3873f23f
@@ -40,7 +40,7 @@ class RewardEstimator(object):
         self.irl_iter = iter(self.data_train)
         if pretrain:
             self.data_train = manager.create_dataset_irl('train', cfg['batchsz'])
-            self.data_valid = manager.create_dataset_irl('valid', cfg['batchsz'])
+            self.data_valid = manager.create_dataset_irl('validation', cfg['batchsz'])
             self.data_test = manager.create_dataset_irl('test', cfg['batchsz'])
             self.irl_iter = iter(self.data_train)
             self.irl_iter_valid = iter(self.data_valid)
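The hunk above renames the requested dataset split from 'valid' to 'validation'. As a minimal sketch of why the exact split name matters, assuming the data manager looks splits up by name (the splits dict and load_split helper below are hypothetical illustrations, not repository code):

    # Hypothetical illustration, not repository code: a loader keyed by split name.
    splits = {'train': [], 'validation': [], 'test': []}

    def load_split(name):
        # A mismatched key like 'valid' fails instead of returning data.
        if name not in splits:
            raise KeyError(f"unknown split '{name}', expected one of {list(splits)}")
        return splits[name]

    load_split('validation')  # matches the key and succeeds
    # load_split('valid')     # would raise KeyError under this naming scheme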
@@ -185,6 +185,7 @@ if __name__ == '__main__':
                         help="Load path for config file")
     parser.add_argument("--seed", type=int, default=0,
                         help="Seed for the policy parameter initialization")
+    parser.add_argument("--pretrain", action='store_true', help="whether to pretrain the reward estimator")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
     parser.add_argument("--save_eval_dials", type=bool, default=False,
@@ -209,7 +210,7 @@ if __name__ == '__main__':
     set_seed(seed)
     policy_sys = GDPL(True, seed=conf['model']['seed'], vectorizer=conf['vectorizer_sys_activated'])
-    rewarder = RewardEstimator(policy_sys.vector, False)
+    rewarder = RewardEstimator(policy_sys.vector, parser.parse_args().pretrain)
     # Load model
     if conf['model']['use_pretrained_initialisation']:
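The two hunks above introduce a --pretrain switch and feed it into the RewardEstimator instead of the hard-coded False. A minimal, self-contained sketch of the flag's behaviour, assuming the common pattern of parsing arguments once and reusing the namespace (the surrounding script here is illustrative, not the repository's training script):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrain", action='store_true',
                        help="whether to pretrain the reward estimator")
    args = parser.parse_args()

    # action='store_true' defaults to False, so merely passing --pretrain enables pretraining.
    # Parsing once into `args` also avoids re-reading sys.argv at each call site,
    # as the inline parser.parse_args().pretrain in the diff does.
    print(args.pretrain)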
......@@ -153,7 +153,7 @@ class PG(Policy):
for p in self.policy.parameters():
p.grad[p.grad != p.grad] = 0.0
# gradient clipping, for stability
torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 10)
# self.lock.acquire() # retain lock to update weights
self.policy_optim.step()
# self.lock.release() # release lock
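The last hunk replaces the removed torch.nn.utils.clip_grad_norm with its in-place successor clip_grad_norm_, which is the API current PyTorch releases support. A minimal sketch of the call in an ordinary training step (the model and data are placeholders, not the GDPL policy):

    import torch

    model = torch.nn.Linear(4, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    optim.zero_grad()
    loss.backward()
    # Rescale gradients in place so their global norm does not exceed 10,
    # mirroring the clipping threshold used in the diff above.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
    optim.step()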