diff --git a/convlab2/policy/gdpl/gdpl.py b/convlab2/policy/gdpl/gdpl.py
index 4479d00f37e0b1bc8e5398b10f198fbd43e08dce..c245a4853b83753725a4b3ed73b0b55d86c73c01 100755
--- a/convlab2/policy/gdpl/gdpl.py
+++ b/convlab2/policy/gdpl/gdpl.py
@@ -64,7 +64,7 @@ class GDPL(Policy):
         s_vec = torch.Tensor(s)
         mask_vec = torch.Tensor(action_mask)
         a = self.policy.select_action(
-            s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu()
+            s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu()
 
         a_counter = 0
         while a.sum() == 0:
diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py
index e5c53e2144fe634d0a30de1768fde49b0208dbea..c2740d13bf8758d872f6310cb88510f173345815 100755
--- a/convlab2/policy/pg/pg.py
+++ b/convlab2/policy/pg/pg.py
@@ -62,7 +62,7 @@ class PG(Policy):
         s_vec = torch.Tensor(s)
         mask_vec = torch.Tensor(action_mask)
         a = self.policy.select_action(
-            s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu()
+            s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu()
 
         a_counter = 0
         while a.sum() == 0:
diff --git a/convlab2/policy/pg/train.py b/convlab2/policy/pg/train.py
index c917088f5542a91fef2274ed6437e697e135efee..3abcd74b99fbaf46529ff07378bb052e1d8c4e97 100755
--- a/convlab2/policy/pg/train.py
+++ b/convlab2/policy/pg/train.py
@@ -179,7 +179,7 @@ if __name__ == '__main__':
     begin_time = datetime.now()
 
     parser = ArgumentParser()
-    parser.add_argument("--path", type=str, default='convlab2/policy/ppo/semantic_level_config.json',
+    parser.add_argument("--path", type=str, default='convlab2/policy/pg/semantic_level_config.json',
                         help="Load path for config file")
     parser.add_argument("--seed", type=int, default=0,
                         help="Seed for the policy parameter initialization")
diff --git a/convlab2/policy/ppo/ppo.py b/convlab2/policy/ppo/ppo.py
index 1de44bed41f6ce804ea4757b6ed3c984710591fc..6635cf2e624e8559efd9ae0cf4616b5312cfb9de 100755
--- a/convlab2/policy/ppo/ppo.py
+++ b/convlab2/policy/ppo/ppo.py
@@ -73,7 +73,7 @@ class PPO(Policy):
         s_vec = torch.Tensor(s)
         mask_vec = torch.Tensor(action_mask)
         a = self.policy.select_action(
-            s_vec.to(device=DEVICE), self.is_train, action_mask=mask_vec.to(device=DEVICE)).cpu()
+            s_vec.to(device=DEVICE), False, action_mask=mask_vec.to(device=DEVICE)).cpu()
 
         a_counter = 0
         while a.sum() == 0:
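
Context for the change: the second positional argument of select_action toggles sampled versus greedy action selection, so hardcoding False makes predict() in GDPL, PG and PPO act deterministically even while self.is_train is set; the pg/train.py hunk only corrects the default config path from the ppo to the pg directory. Below is a minimal, hypothetical sketch of the sampled/greedy distinction under an action mask (the function name, masking behaviour and 0.5 threshold are assumptions for illustration, not the ConvLab-2 implementation):

import torch

def select_action_sketch(probs, sample, action_mask):
    # Toy illustration only: pick a multi-binary dialogue-act vector.
    # probs       -- per-act probabilities in [0, 1], shape [a_dim]
    # sample      -- True: draw each act from a Bernoulli; False: threshold at 0.5
    # action_mask -- 1.0 for allowed acts, 0.0 for disallowed acts, shape [a_dim]
    masked = probs * action_mask           # disallowed acts get probability 0
    if sample:                             # stochastic branch (exploration)
        return torch.bernoulli(masked)
    return (masked > 0.5).float()          # deterministic branch (evaluation)

With this patch applied, predict() always passes False, i.e. always takes the deterministic branch of its select_action, regardless of whether the policy is being trained.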