From 1824b5412eccfb18c4bdf510615e3ad17d5375aa Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Thu, 17 Mar 2022 12:17:58 +0100 Subject: [PATCH] sampling is set to False in predict methods --- convlab2/policy/gdpl/gdpl.py | 2 +- convlab2/policy/pg/pg.py | 2 +- convlab2/policy/pg/train.py | 2 +- convlab2/policy/ppo/ppo.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/convlab2/policy/gdpl/gdpl.py b/convlab2/policy/gdpl/gdpl.py index 4479d00f..c245a485 100755 --- a/convlab2/policy/gdpl/gdpl.py +++ b/convlab2/policy/gdpl/gdpl.py @@ -64,7 +64,7 @@ class GDPL(Policy): s_vec = torch.Tensor(s) mask_vec = torch.Tensor(action_mask) a = self.policy.select_action( - s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu() + s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu() a_counter = 0 while a.sum() == 0: diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py index e5c53e21..c2740d13 100755 --- a/convlab2/policy/pg/pg.py +++ b/convlab2/policy/pg/pg.py @@ -62,7 +62,7 @@ class PG(Policy): s_vec = torch.Tensor(s) mask_vec = torch.Tensor(action_mask) a = self.policy.select_action( - s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu() + s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu() a_counter = 0 while a.sum() == 0: diff --git a/convlab2/policy/pg/train.py b/convlab2/policy/pg/train.py index c917088f..3abcd74b 100755 --- a/convlab2/policy/pg/train.py +++ b/convlab2/policy/pg/train.py @@ -179,7 +179,7 @@ if __name__ == '__main__': begin_time = datetime.now() parser = ArgumentParser() - parser.add_argument("--path", type=str, default='convlab2/policy/ppo/semantic_level_config.json', + parser.add_argument("--path", type=str, default='convlab2/policy/pg/semantic_level_config.json', help="Load path for config file") parser.add_argument("--seed", type=int, default=0, help="Seed for the policy parameter initialization") diff --git a/convlab2/policy/ppo/ppo.py b/convlab2/policy/ppo/ppo.py index 1de44bed..6635cf2e 100755 --- a/convlab2/policy/ppo/ppo.py +++ b/convlab2/policy/ppo/ppo.py @@ -73,7 +73,7 @@ class PPO(Policy): s_vec = torch.Tensor(s) mask_vec = torch.Tensor(action_mask) a = self.policy.select_action( - s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu() + s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu() a_counter = 0 while a.sum() == 0: -- GitLab