From 1824b5412eccfb18c4bdf510615e3ad17d5375aa Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Thu, 17 Mar 2022 12:17:58 +0100
Subject: [PATCH] sampling is set to False in predict methods

---
 convlab2/policy/gdpl/gdpl.py | 2 +-
 convlab2/policy/pg/pg.py     | 2 +-
 convlab2/policy/pg/train.py  | 2 +-
 convlab2/policy/ppo/ppo.py   | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/convlab2/policy/gdpl/gdpl.py b/convlab2/policy/gdpl/gdpl.py
index 4479d00f..c245a485 100755
--- a/convlab2/policy/gdpl/gdpl.py
+++ b/convlab2/policy/gdpl/gdpl.py
@@ -64,7 +64,7 @@ class GDPL(Policy):
         s_vec = torch.Tensor(s)
         mask_vec = torch.Tensor(action_mask)
         a = self.policy.select_action(
-            s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu()
+            s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu()
 
         a_counter = 0
         while a.sum() == 0:
diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py
index e5c53e21..c2740d13 100755
--- a/convlab2/policy/pg/pg.py
+++ b/convlab2/policy/pg/pg.py
@@ -62,7 +62,7 @@ class PG(Policy):
         s_vec = torch.Tensor(s)
         mask_vec = torch.Tensor(action_mask)
         a = self.policy.select_action(
-            s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu()
+            s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu()
 
         a_counter = 0
         while a.sum() == 0:
diff --git a/convlab2/policy/pg/train.py b/convlab2/policy/pg/train.py
index c917088f..3abcd74b 100755
--- a/convlab2/policy/pg/train.py
+++ b/convlab2/policy/pg/train.py
@@ -179,7 +179,7 @@ if __name__ == '__main__':
 
     begin_time = datetime.now()
     parser = ArgumentParser()
-    parser.add_argument("--path", type=str, default='convlab2/policy/ppo/semantic_level_config.json',
+    parser.add_argument("--path", type=str, default='convlab2/policy/pg/semantic_level_config.json',
                         help="Load path for config file")
     parser.add_argument("--seed", type=int, default=0,
                         help="Seed for the policy parameter initialization")
diff --git a/convlab2/policy/ppo/ppo.py b/convlab2/policy/ppo/ppo.py
index 1de44bed..6635cf2e 100755
--- a/convlab2/policy/ppo/ppo.py
+++ b/convlab2/policy/ppo/ppo.py
@@ -73,7 +73,7 @@ class PPO(Policy):
         s_vec = torch.Tensor(s)
         mask_vec = torch.Tensor(action_mask)
         a = self.policy.select_action(
-            s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu()
+            s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu()
 
         a_counter = 0
         while a.sum() == 0:
-- 
GitLab