From b88d547fd3c4f271a2d80476cdeedcb4ef7b0bf2 Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Thu, 17 Mar 2022 14:52:56 +0100
Subject: [PATCH] Fix GDPL reward estimator pretraining and PG gradient clipping

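Three small fixes:

* estimator.py: use the split name 'validation' rather than 'valid' when
  building the IRL validation dataset for pretraining.
* train.py: add a --pretrain flag and pass it to the RewardEstimator
  instead of hard-coding False. Assuming the script is invoked from the
  repository root, pretraining can then be enabled with e.g.

      python convlab2/policy/gdpl/train.py --pretrain

* pg.py: use torch.nn.utils.clip_grad_norm_ for gradient clipping;
  clip_grad_norm is deprecated in recent PyTorch versions.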
---
 convlab2/policy/gdpl/estimator.py | 2 +-
 convlab2/policy/gdpl/train.py     | 3 ++-
 convlab2/policy/pg/pg.py          | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/convlab2/policy/gdpl/estimator.py b/convlab2/policy/gdpl/estimator.py
index 985ca33d..3012465d 100755
--- a/convlab2/policy/gdpl/estimator.py
+++ b/convlab2/policy/gdpl/estimator.py
@@ -40,7 +40,7 @@ class RewardEstimator(object):
         self.irl_iter = iter(self.data_train)
         if pretrain:
             self.data_train = manager.create_dataset_irl('train', cfg['batchsz'])
-            self.data_valid = manager.create_dataset_irl('valid', cfg['batchsz'])
+            self.data_valid = manager.create_dataset_irl('validation', cfg['batchsz'])
             self.data_test = manager.create_dataset_irl('test', cfg['batchsz'])
             self.irl_iter = iter(self.data_train)
             self.irl_iter_valid = iter(self.data_valid)
diff --git a/convlab2/policy/gdpl/train.py b/convlab2/policy/gdpl/train.py
index 1603618a..bb3d1e1e 100755
--- a/convlab2/policy/gdpl/train.py
+++ b/convlab2/policy/gdpl/train.py
@@ -185,6 +185,7 @@ if __name__ == '__main__':
                         help="Load path for config file")
     parser.add_argument("--seed", type=int, default=0,
                         help="Seed for the policy parameter initialization")
+    parser.add_argument("--pretrain", action='store_true', help="whether to pretrain the reward estimator")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
     parser.add_argument("--save_eval_dials", type=bool, default=False,
@@ -209,7 +210,7 @@ if __name__ == '__main__':
     set_seed(seed)
 
     policy_sys = GDPL(True, seed=conf['model']['seed'], vectorizer=conf['vectorizer_sys_activated'])
-    rewarder = RewardEstimator(policy_sys.vector, False)
+    rewarder = RewardEstimator(policy_sys.vector, parser.parse_args().pretrain)
 
     # Load model
     if conf['model']['use_pretrained_initialisation']:
diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py
index c2740d13..8b4088c8 100755
--- a/convlab2/policy/pg/pg.py
+++ b/convlab2/policy/pg/pg.py
@@ -153,7 +153,7 @@ class PG(Policy):
                 for p in self.policy.parameters():
                     p.grad[p.grad != p.grad] = 0.0
                 # gradient clipping, for stability
-                torch.nn.utils.clip_grad_norm(self.policy.parameters(), 10)
+                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 10)
                 # self.lock.acquire() # retain lock to update weights
                 self.policy_optim.step()
                 # self.lock.release() # release lock
-- 
GitLab