From 44746e90dc83c37de7c34d300e154fd643178c58 Mon Sep 17 00:00:00 2001
From: Christian Geishauser <45534723+ChrisGeishauser@users.noreply.github.com>
Date: Tue, 24 Jan 2023 17:14:58 +0100
Subject: [PATCH] changed ppo to use number of dialogues for training (#124)

* changed ppo to use number of dialogues for training

* Revert number of epochs

Co-authored-by: Christian <christian.geishauser@hhu.de>
Co-authored-by: Carel van Niekerk <niekerk@hhu.de>
---
 convlab/policy/ppo/README.md                  |  2 +-
 .../ppo/semanticGenTUS-RuleDST-PPOPolicy.json |  2 +-
 convlab/policy/ppo/semantic_level_config.json |  2 +-
 convlab/policy/ppo/setsumbt_config.json       |  2 +-
 convlab/policy/ppo/setsumbt_unc_config.json   |  2 +-
 convlab/policy/ppo/train.py                   | 28 ++++++++++++------------
 convlab/policy/ppo/trippy_config.json         |  2 +-
 .../policy/ppo/tus_semantic_level_config.json |  2 +-
 8 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/convlab/policy/ppo/README.md b/convlab/policy/ppo/README.md
index c762253c..9efb621b 100755
--- a/convlab/policy/ppo/README.md
+++ b/convlab/policy/ppo/README.md
@@ -21,7 +21,7 @@ One example for the environment-config is **semantic_level_config.json**, where
 - num_eval_dialogues: how many evaluation dialogues should be used
 - epoch: how many training epochs to run. One epoch consists of collecting dialogues + performing an update
 - eval_frequency: after how many epochs perform an evaluation
-- batchsz: the number of training dialogues collected before doing an update
+- num_train_dialogues: the number of training dialogues collected before doing an update
 
 Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
 
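For orientation, an abbreviated environment-config using the renamed field might look like the sketch below. It follows semantic_level_config.json as changed in this patch; the values are illustrative, the surrounding "model" block is inferred from how train.py reads the config, and the pipeline sections (user policy, NLU, etc.) are omitted.

```json
{
	"model": {
		"load_path": "",
		"use_pretrained_initialisation": false,
		"pretrained_load_path": "",
		"num_train_dialogues": 100,
		"seed": 0,
		"epoch": 10,
		"eval_frequency": 5
	}
}
```
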
diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
index 0e8774e2..7e170f6d 100644
--- a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
+++ b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
@@ -3,7 +3,7 @@
 		"load_path": "convlab/policy/ppo/pretrained_models/mle",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 500,
+		"num_train_dialogues": 100,
 		"seed": 0,
 		"epoch": 50,
 		"eval_frequency": 5,
diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/semantic_level_config.json
index 04b0626a..0a16328a 100644
--- a/convlab/policy/ppo/semantic_level_config.json
+++ b/convlab/policy/ppo/semantic_level_config.json
@@ -3,7 +3,7 @@
 		"load_path": "",
 		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
 		"epoch": 10,
 		"eval_frequency": 5,
diff --git a/convlab/policy/ppo/setsumbt_config.json b/convlab/policy/ppo/setsumbt_config.json
index b6a02adb..bf921100 100644
--- a/convlab/policy/ppo/setsumbt_config.json
+++ b/convlab/policy/ppo/setsumbt_config.json
@@ -3,7 +3,7 @@
 		"load_path": "",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
 		"epoch": 50,
 		"eval_frequency": 5,
diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/setsumbt_unc_config.json
index fafdb3fc..a80c04c9 100644
--- a/convlab/policy/ppo/setsumbt_unc_config.json
+++ b/convlab/policy/ppo/setsumbt_unc_config.json
@@ -3,7 +3,7 @@
 		"load_path": "",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
 		"epoch": 50,
 		"eval_frequency": 5,
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 703a5500..42fd9425 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -33,7 +33,7 @@ except RuntimeError:
     pass
 
 
-def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
+def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0):
 
     """
     This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
@@ -60,7 +60,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
 
     set_seed(train_seed)
 
-    while sampled_num < batchsz:
+    while sampled_traj_num < num_dialogues:
         # for each trajectory, we reset the env and get initial state
         s = env.reset()
         for t in range(traj_len):
@@ -108,7 +108,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     evt.wait()
 
 
-def sample(env, policy, batchsz, process_num, seed):
+def sample(env, policy, num_train_dialogues, process_num, seed):
 
     """
-    Given batchsz number of task, the batchsz will be splited equally to each processes
+    Given num_train_dialogues dialogues to collect, the work will be split equally across the processes
@@ -122,7 +122,7 @@ def sample(env, policy, batchsz, process_num, seed):
 
-    # batchsz will be splitted into each process,
-    # final batchsz maybe larger than batchsz parameters
-    process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
+    # num_train_dialogues will be split across the processes,
+    # so the final number of collected dialogues may be slightly larger than num_train_dialogues
+    process_num_dialogues = np.ceil(num_train_dialogues / process_num).astype(np.int32)
     train_seeds = random.sample(range(0, 1000), process_num)
     # buffer to save all data
     queue = mp.Queue()
@@ -137,7 +137,7 @@ def sample(env, policy, batchsz, process_num, seed):
     evt = mp.Event()
     processes = []
     for i in range(process_num):
-        process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
+        process_args = (i, queue, evt, env, policy, process_num_dialogues, train_seeds[i])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
@@ -157,10 +157,10 @@ def sample(env, policy, batchsz, process_num, seed):
     return buff.get_batch()
 
 
-def update(env, policy, batchsz, epoch, process_num, seed=0):
+def update(env, policy, num_dialogues, epoch, process_num, seed=0):
 
     # sample data asynchronously
-    batch = sample(env, policy, batchsz, process_num, seed)
+    batch = sample(env, policy, num_dialogues, process_num, seed)
 
     # print(batch)
     # data in batch is : batch.state: ([1, s_dim], [1, s_dim]...)
@@ -224,7 +224,7 @@ if __name__ == '__main__':
         logging.info("Policy initialised from scratch")
 
     log_start_args(conf)
-    logging.info(f"New episodes per epoch: {conf['model']['batchsz']}")
+    logging.info(f"New episodes per epoch: {conf['model']['num_train_dialogues']}")
 
     env, sess = env_config(conf, policy_sys)
 
@@ -250,11 +250,11 @@ if __name__ == '__main__':
     for i in range(conf['model']['epoch']):
         idx = i + 1
         # print("Epoch :{}".format(str(idx)))
-        update(env, policy_sys, conf['model']['batchsz'], idx, conf['model']['process_num'], seed=seed)
+        update(env, policy_sys, conf['model']['num_train_dialogues'], idx, conf['model']['process_num'], seed=seed)
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
+            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['num_train_dialogues']} - {time_now}" + '-' * 60)
 
             eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
@@ -264,7 +264,7 @@ if __name__ == '__main__':
                           eval_dict["avg_return"], save_path)
             policy_sys.save(save_path, "last")
             for key in eval_dict:
-                tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['batchsz'])
+                tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['num_train_dialogues'])
 
     logging.info("End of Training: " +
                  time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
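A note on how the renamed parameter behaves: sample() still splits the requested dialogue budget across the worker processes with np.ceil, so the effective total can be slightly larger than requested. For example, with the num_train_dialogues of 100 set in these configs and a hypothetical process_num of 8, each worker collects ceil(100 / 8) = 13 dialogues, giving 104 dialogues per update in total.
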
diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/trippy_config.json
index 41b1c362..fdb1b2f3 100644
--- a/convlab/policy/ppo/trippy_config.json
+++ b/convlab/policy/ppo/trippy_config.json
@@ -3,7 +3,7 @@
 		"load_path": "/path/to/model/checkpoint",
 		"pretrained_load_path": "",
 		"use_pretrained_initialisation": false,
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
 		"epoch": 50,
 		"eval_frequency": 5,
diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/tus_semantic_level_config.json
index 9d56646c..84dfff0c 100644
--- a/convlab/policy/ppo/tus_semantic_level_config.json
+++ b/convlab/policy/ppo/tus_semantic_level_config.json
@@ -3,7 +3,7 @@
 		"load_path": "convlab/policy/ppo/pretrained_models/mle",
 		"use_pretrained_initialisation": false,
 		"pretrained_load_path": "",
-		"batchsz": 1000,
+		"num_train_dialogues": 100,
 		"seed": 0,
 		"epoch": 50,
 		"eval_frequency": 5,
-- 
GitLab