diff --git a/convlab/policy/ppo/README.md b/convlab/policy/ppo/README.md
index c762253ce4fe769bb2c540b39d39df713881a7f3..9efb621bcaa49cdcaba0446b90c5c80f77f77876 100755
--- a/convlab/policy/ppo/README.md
+++ b/convlab/policy/ppo/README.md
@@ -21,7 +21,7 @@ One example for the environment-config is **semantic_level_config.json**, where
 - num_eval_dialogues: how many evaluation dialogues should be used
 - epoch: how many training epochs to run. One epoch consists of collecting dialogues + performing an update
 - eval_frequency: after how many epochs perform an evaluation
-- batchsz: the number of training dialogues collected before doing an update
+- num_train_dialogues: the number of training dialogues collected before doing an update
 
 Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
 
diff --git a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
index 0e8774e20898c2855f368127f8e14e193ac2c21d..7e170f6ddb65798771bb5e497b6a9dbf7e6013f0 100644
--- a/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
+++ b/convlab/policy/ppo/semanticGenTUS-RuleDST-PPOPolicy.json
@@ -3,7 +3,7 @@
         "load_path": "convlab/policy/ppo/pretrained_models/mle",
         "pretrained_load_path": "",
         "use_pretrained_initialisation": false,
-        "batchsz": 500,
+        "num_train_dialogues": 100,
         "seed": 0,
         "epoch": 50,
         "eval_frequency": 5,
diff --git a/convlab/policy/ppo/semantic_level_config.json b/convlab/policy/ppo/semantic_level_config.json
index 04b0626a10bc8d48add16732df26a7cc00a35088..0a16328aacf1c31c45630dab312b96eaa8f333e7 100644
--- a/convlab/policy/ppo/semantic_level_config.json
+++ b/convlab/policy/ppo/semantic_level_config.json
@@ -3,7 +3,7 @@
         "load_path": "",
         "use_pretrained_initialisation": false,
         "pretrained_load_path": "",
-        "batchsz": 1000,
+        "num_train_dialogues": 100,
         "seed": 0,
         "epoch": 10,
         "eval_frequency": 5,
diff --git a/convlab/policy/ppo/setsumbt_config.json b/convlab/policy/ppo/setsumbt_config.json
index b6a02adbf371bfea63e3e156a2d9e47f13456c78..bf9211006b6e2623016acfec18573768f73558fd 100644
--- a/convlab/policy/ppo/setsumbt_config.json
+++ b/convlab/policy/ppo/setsumbt_config.json
@@ -3,7 +3,7 @@
         "load_path": "",
         "pretrained_load_path": "",
         "use_pretrained_initialisation": false,
-        "batchsz": 1000,
+        "num_train_dialogues": 100,
         "seed": 0,
         "epoch": 50,
         "eval_frequency": 5,
diff --git a/convlab/policy/ppo/setsumbt_unc_config.json b/convlab/policy/ppo/setsumbt_unc_config.json
index fafdb3fc9bd8f7fe09e3759d58a591cf964fb93b..a80c04c9656dbdb361de1ca74e3ca24db028b1cf 100644
--- a/convlab/policy/ppo/setsumbt_unc_config.json
+++ b/convlab/policy/ppo/setsumbt_unc_config.json
@@ -3,7 +3,7 @@
         "load_path": "",
         "pretrained_load_path": "",
         "use_pretrained_initialisation": false,
-        "batchsz": 1000,
+        "num_train_dialogues": 100,
         "seed": 0,
         "epoch": 50,
         "eval_frequency": 5,
diff --git a/convlab/policy/ppo/train.py b/convlab/policy/ppo/train.py
index 703a55005b8c07578b85765a626d9871deebf26e..42fd9425d1416555da52073f15635cf7cfd99997 100755
--- a/convlab/policy/ppo/train.py
+++ b/convlab/policy/ppo/train.py
@@ -33,7 +33,7 @@ except RuntimeError:
     pass
 
 
-def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
+def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0):
     """
     This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
@@ -60,7 +60,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
 
     set_seed(train_seed)
 
-    while sampled_num < batchsz:
+    while sampled_traj_num < num_dialogues:
         # for each trajectory, we reset the env and get initial state
         s = env.reset()
         for t in range(traj_len):
@@ -108,7 +108,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     evt.wait()
 
 
-def sample(env, policy, batchsz, process_num, seed):
+def sample(env, policy, num_train_dialogues, process_num, seed):
     """
     Given batchsz number of task, the batchsz will be splited equally to each processes
@@ -122,7 +122,7 @@ def sample(env, policy, batchsz, process_num, seed):
 
     # batchsz will be splitted into each process,
     # final batchsz maybe larger than batchsz parameters
-    process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
+    process_num_dialogues = np.ceil(num_train_dialogues / process_num).astype(np.int32)
     train_seeds = random.sample(range(0, 1000), process_num)
     # buffer to save all data
     queue = mp.Queue()
@@ -137,7 +137,7 @@ def sample(env, policy, batchsz, process_num, seed):
     evt = mp.Event()
     processes = []
     for i in range(process_num):
-        process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
+        process_args = (i, queue, evt, env, policy, process_num_dialogues, train_seeds[i])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
@@ -157,10 +157,10 @@ def sample(env, policy, batchsz, process_num, seed):
     return buff.get_batch()
 
 
-def update(env, policy, batchsz, epoch, process_num, seed=0):
+def update(env, policy, num_dialogues, epoch, process_num, seed=0):
 
     # sample data asynchronously
-    batch = sample(env, policy, batchsz, process_num, seed)
+    batch = sample(env, policy, num_dialogues, process_num, seed)
 
     # print(batch)
     # data in batch is : batch.state: ([1, s_dim], [1, s_dim]...)
@@ -224,7 +224,7 @@ if __name__ == '__main__':
         logging.info("Policy initialised from scratch")
 
     log_start_args(conf)
-    logging.info(f"New episodes per epoch: {conf['model']['batchsz']}")
+    logging.info(f"New episodes per epoch: {conf['model']['num_train_dialogues']}")
 
     env, sess = env_config(conf, policy_sys)
 
@@ -250,11 +250,11 @@ if __name__ == '__main__':
     for i in range(conf['model']['epoch']):
         idx = i + 1
         # print("Epoch :{}".format(str(idx)))
-        update(env, policy_sys, conf['model']['batchsz'], idx, conf['model']['process_num'], seed=seed)
+        update(env, policy_sys, conf['model']['num_train_dialogues'], idx, conf['model']['process_num'], seed=seed)
 
         if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
             time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
+            logging.info(f"Evaluating after Dialogues: {idx * conf['model']['num_train_dialogues']} - {time_now}" + '-' * 60)
 
             eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
 
@@ -264,7 +264,7 @@ if __name__ == '__main__':
                           eval_dict["avg_return"], save_path)
             policy_sys.save(save_path, "last")
             for key in eval_dict:
-                tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['batchsz'])
+                tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['num_train_dialogues'])
 
     logging.info("End of Training: " + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
 
diff --git a/convlab/policy/ppo/trippy_config.json b/convlab/policy/ppo/trippy_config.json
index 41b1c3623aca944312c6389e55e34d72422fb6e0..fdb1b2f3ba82e4e9d3c585319734affb0f4f1155 100644
--- a/convlab/policy/ppo/trippy_config.json
+++ b/convlab/policy/ppo/trippy_config.json
@@ -3,7 +3,7 @@
         "load_path": "/path/to/model/checkpoint",
         "pretrained_load_path": "",
         "use_pretrained_initialisation": false,
-        "batchsz": 1000,
+        "num_train_dialogues": 100,
         "seed": 0,
         "epoch": 50,
         "eval_frequency": 5,
diff --git a/convlab/policy/ppo/tus_semantic_level_config.json b/convlab/policy/ppo/tus_semantic_level_config.json
index 9d56646cc857de85b47a1f9925a0e4bf89d8b524..84dfff0c337a9360cf376b72f6396f7e7ab91c48 100644
--- a/convlab/policy/ppo/tus_semantic_level_config.json
+++ b/convlab/policy/ppo/tus_semantic_level_config.json
@@ -3,7 +3,7 @@
         "load_path": "convlab/policy/ppo/pretrained_models/mle",
         "use_pretrained_initialisation": false,
         "pretrained_load_path": "",
-        "batchsz": 1000,
+        "num_train_dialogues": 100,
         "seed": 0,
         "epoch": 50,
         "eval_frequency": 5,
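
Not part of the patch itself, but a minimal sketch of what the renamed setting controls: num_train_dialogues (read from conf['model'] in train.py above) is the number of dialogues collected per epoch, and the patched sample() splits it across worker processes with the same np.ceil division shown in the train.py hunk. The inline config dict and the process_num value below are illustrative stand-ins, not taken from the repository files:

import numpy as np

# Illustrative slice of a PPO environment config after this patch; only the
# keys needed for the split are shown, and process_num is an assumed value.
conf = {
    "model": {
        "num_train_dialogues": 100,  # dialogues collected per epoch before each update
        "process_num": 2,            # parallel sampler processes (assumed here)
    }
}

num_train_dialogues = conf["model"]["num_train_dialogues"]
process_num = conf["model"]["process_num"]

# Mirrors the split in the patched sample(): each worker collects
# ceil(num_train_dialogues / process_num) dialogues, so the total collected
# can slightly exceed num_train_dialogues when it does not divide evenly.
process_num_dialogues = np.ceil(num_train_dialogues / process_num).astype(np.int32)

print(process_num_dialogues)                 # -> 50 dialogues per worker
print(process_num_dialogues * process_num)   # -> 100 dialogues in total per epoch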