From 77e9e288dacc7d5409c794d723b198341ef33f94 Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Tue, 24 Jan 2023 16:50:47 +0100
Subject: [PATCH] changed ppo to use number of dialogues for training

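Rename the "batchsz" option to "num_train_dialogues" and make the sampler
count completed dialogues (trajectories) rather than individual samples, so
a policy update is performed after a fixed number of training dialogues.
The README, train.py and all PPO environment configs are updated to the new
key, and the epoch values are adjusted to the smaller per-epoch dialogue
count.

A minimal sketch of how the renamed option appears in the "model" section of
an environment config (only the keys relevant to this change are shown; the
values are illustrative and taken from the updated configs):

    "model": {
        "num_train_dialogues": 100,
        "epoch": 100,
        "eval_frequency": 5,
        "num_eval_dialogues": 500
    }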
---
 convlab/policy/ppo/README.md                  |  2 +-
 .../ppo/semanticGenTUS-RuleDST-PPOPolicy.json |  4 ++--
 convlab/policy/ppo/semantic_level_config.json |  4 ++--
 convlab/policy/ppo/setsumbt_config.json       |  4 ++--
 convlab/policy/ppo/setsumbt_unc_config.json   |  4 ++--
 convlab/policy/ppo/train.py                   | 22 +++++++++----------
 convlab/policy/ppo/trippy_config.json         |  4 ++--
 .../policy/ppo/tus_semantic_level_config.json |  4 ++--
 8 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/convlab/policy/ppo/README.md b/convlab/policy/ppo/README.md
index c762253c..9efb621b 100755
--- a/convlab/policy/ppo/README.md
+++ b/convlab/policy/ppo/README.md
@@ -21,7 +21,7 @@ One example for the environment-config is **semantic_level_config.json**, where
 - num_eval_dialogues: how many evaluation dialogues should be used
 - epoch: how many training epochs to run. One epoch consists of collecting dialogues + performing an update
 - eval_frequency: after how many epochs to perform an evaluation
-- batchsz: the number of training dialogues collected before doing an update
+- num_train_dialogues: the number of training dialogues collected before doing an update
 
 Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
 
diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
index 0e8774e2..e0a4a4ec 100644
--- a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
+++ b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
@@ -3,9 +3,9 @@
 		"load_path": "convlab/policy/ppo/pretrained_models/mle",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 500,
+		"num_train_dialogues": 100,
 		"seed": 0,
-		"epoch": 50,
+		"epoch": 500,
 		"eval_frequency": 5,
 		"process_num": 1,
 		"num_eval_dialogues": 500,
diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/semantic_level_config.json
index 04b0626a..fd627b15 100644
--- a/convlab/policy/ppo/semantic_level_config.json
+++ b/convlab/policy/ppo/semantic_level_config.json
@@ -3,9 +3,9 @@
 		"load_path": "",
 		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
-		"epoch": 10,
+		"epoch": 100,
 		"eval_frequency": 5,
 		"process_num": 4,
 		"sys_semantic_to_usr": false,
diff --git a/convlab/policy/ppo/setsumbt_config.json b/convlab/policy/ppo/setsumbt_config.json
index b6a02adb..f3e7e442 100644
--- a/convlab/policy/ppo/setsumbt_config.json
+++ b/convlab/policy/ppo/setsumbt_config.json
@@ -3,9 +3,9 @@
 		"load_path": "",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
-		"epoch": 50,
+		"epoch": 500,
 		"eval_frequency": 5,
 		"process_num": 2,
 		"num_eval_dialogues": 500,
diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/setsumbt_unc_config.json
index fafdb3fc..df95e0fb 100644
--- a/convlab/policy/ppo/setsumbt_unc_config.json
+++ b/convlab/policy/ppo/setsumbt_unc_config.json
@@ -3,9 +3,9 @@
 		"load_path": "",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
-		"epoch": 50,
+		"epoch": 500,
 		"eval_frequency": 5,
 		"process_num": 2,
 		"num_eval_dialogues": 500,
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 703a5500..42fd9425 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -33,7 +33,7 @@ except RuntimeError:
     pass
 
 
-def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
+def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0):
 
     """
     This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
@@ -60,7 +60,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
 
     set_seed(train_seed)
 
-    while sampled_num < batchsz:
+    while sampled_traj_num < num_dialogues:
         # for each trajectory, we reset the env and get initial state
         s = env.reset()
         for t in range(traj_len):
@@ -108,7 +108,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     evt.wait()
 
 
-def sample(env, policy, batchsz, process_num, seed):
+def sample(env, policy, num_train_dialogues, process_num, seed):
 
     """
     Given num_train_dialogues dialogues to collect, the total is split equally across the processes
@@ -122,7 +122,7 @@ def sample(env, policy, batchsz, process_num, seed):
 
     # num_train_dialogues is split across the processes; because of rounding up,
     # the final number of collected dialogues may be larger than requested
-    process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
+    process_num_dialogues = np.ceil(num_train_dialogues / process_num).astype(np.int32)
     train_seeds = random.sample(range(0, 1000), process_num)
     # buffer to save all data
     queue = mp.Queue()
@@ -137,7 +137,7 @@ def sample(env, policy, batchsz, process_num, seed):
     evt = mp.Event()
     processes = []
     for i in range(process_num):
-        process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
+        process_args = (i, queue, evt, env, policy, process_num_dialogues, train_seeds[i])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as a daemon, and it will be killed once the main process is stopped.
@@ -157,10 +157,10 @@ def sample(env, policy, batchsz, process_num, seed):
     return buff.get_batch()
 
 
-def update(env, policy, batchsz, epoch, process_num, seed=0):
+def update(env, policy, num_dialogues, epoch, process_num, seed=0):
 
     # sample data asynchronously
-    batch = sample(env, policy, batchsz, process_num, seed)
+    batch = sample(env, policy, num_dialogues, process_num, seed)
 
     # print(batch)
     # data in batch is : batch.state: ([1, s_dim], [1, s_dim]...)
@@ -224,7 +224,7 @@ if __name__ == '__main__':
         logging.info("Policy initialised from scratch")
 
     log_start_args(conf)
-    logging.info(f"New episodes per epoch: {conf['model']['batchsz']}")
+    logging.info(f"New episodes per epoch: {conf['model']['num_train_dialogues']}")
 
     env, sess = env_config(conf, policy_sys)
 
@@ -250,11 +250,11 @@ if __name__ == '__main__':
     for i in range(conf['model']['epoch']):
         idx = i + 1
         # print("Epoch :{}".format(str(idx)))
-        update(env, policy_sys, conf['model']['batchsz'], idx, conf['model']['process_num'], seed=seed)
+        update(env, policy_sys, conf['model']['num_train_dialogues'], idx, conf['model']['process_num'], seed=seed)
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
+            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['num_train_dialogues']} - {time_now}" + '-' * 60)
 
             eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
@@ -264,7 +264,7 @@ if __name__ == '__main__':
                           eval_dict["avg_return"], save_path)
             policy_sys.save(save_path, "last")
             for key in eval_dict:
-                tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['batchsz'])
+                tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['num_train_dialogues'])
 
     logging.info("End of Training: " +
                  time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/trippy_config.json
index 41b1c362..31430bfc 100644
--- a/convlab/policy/ppo/trippy_config.json
+++ b/convlab/policy/ppo/trippy_config.json
@@ -3,9 +3,9 @@
 		"load_path": "/path/to/model/checkpoint",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
-		"epoch": 50,
+		"epoch": 500,
 		"eval_frequency": 5,
 		"process_num": 2,
 		"num_eval_dialogues": 500,
diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/tus_semantic_level_config.json
index 9d56646c..14dacf9e 100644
--- a/convlab/policy/ppo/tus_semantic_level_config.json
+++ b/convlab/policy/ppo/tus_semantic_level_config.json
@@ -3,9 +3,9 @@
 		"load_path": "convlab/policy/ppo/pretrained_models/mle",
 		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
-		"epoch": 50,
+		"epoch": 500,
 		"eval_frequency": 5,
 		"process_num": 4,
 		"sys_semantic_to_usr": false,
-- 
GitLab