Skip to content
Snippets Groups Projects
Unverified Commit 44746e90 authored by Christian Geishauser's avatar Christian Geishauser Committed by GitHub
Browse files

changed ppo to use number of dialogues for training (#124)


* changed ppo to use number of dialogues for training

* Revert number of epochs

Co-authored-by: Christian <christian.geishauser@hhu.de>
Co-authored-by: Carel van Niekerk <niekerk@hhu.de>
parent 86ed7ad3
Branches
No related tags found
No related merge requests found
......@@ -21,7 +21,7 @@ One example for the environment-config is **semantic_level_config.json**, where
- num_eval_dialogues: how many evaluation dialogues should be used
- epoch: how many training epochs to run. One epoch consists of collecting dialogues + performing an update
- eval_frequency: after how many epochs perform an evaluation
- batchsz: the number of training dialogues collected before doing an update
- num_train_dialogues: the number of training dialogues collected before doing an update
Moreover, you can specify the full dialogue pipeline here, such as the user policy, NLU for system and user, etc.
......
......@@ -3,7 +3,7 @@
"load_path": "convlab/policy/ppo/pretrained_models/mle",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 500,
"num_train_dialogues": 100,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
......
......@@ -3,7 +3,7 @@
"load_path": "",
"use_pretrained_initialisation": false,
"pretrained_load_path": "",
"batchsz": 1000,
"num_train_dialogues": 100,
"seed": 0,
"epoch": 10,
"eval_frequency": 5,
......
......@@ -3,7 +3,7 @@
"load_path": "",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 1000,
"num_train_dialogues": 100,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
......
......@@ -3,7 +3,7 @@
"load_path": "",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 1000,
"num_train_dialogues": 100,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
......
......@@ -33,7 +33,7 @@ except RuntimeError:
pass
def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
def sampler(pid, queue, evt, env, policy, num_dialogues, train_seed=0):
"""
This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
......@@ -60,7 +60,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
set_seed(train_seed)
while sampled_num < batchsz:
while sampled_traj_num < num_dialogues:
# for each trajectory, we reset the env and get initial state
s = env.reset()
for t in range(traj_len):
......@@ -108,7 +108,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
evt.wait()
def sample(env, policy, batchsz, process_num, seed):
def sample(env, policy, num_train_dialogues, process_num, seed):
"""
    Given num_train_dialogues tasks, the total will be split equally among the processes
......@@ -122,7 +122,7 @@ def sample(env, policy, batchsz, process_num, seed):
	# num_train_dialogues will be split across the processes,
	# so the final total may be slightly larger than the num_train_dialogues parameter
process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
process_num_dialogues = np.ceil(num_train_dialogues / process_num).astype(np.int32)
train_seeds = random.sample(range(0, 1000), process_num)
# buffer to save all data
queue = mp.Queue()
......@@ -137,7 +137,7 @@ def sample(env, policy, batchsz, process_num, seed):
evt = mp.Event()
processes = []
for i in range(process_num):
process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
process_args = (i, queue, evt, env, policy, process_num_dialogues, train_seeds[i])
processes.append(mp.Process(target=sampler, args=process_args))
for p in processes:
        # set the process as a daemon, so it will be killed once the main process is stopped.
......@@ -157,10 +157,10 @@ def sample(env, policy, batchsz, process_num, seed):
return buff.get_batch()
def update(env, policy, batchsz, epoch, process_num, seed=0):
def update(env, policy, num_dialogues, epoch, process_num, seed=0):
# sample data asynchronously
batch = sample(env, policy, batchsz, process_num, seed)
batch = sample(env, policy, num_dialogues, process_num, seed)
# print(batch)
# data in batch is : batch.state: ([1, s_dim], [1, s_dim]...)
......@@ -224,7 +224,7 @@ if __name__ == '__main__':
logging.info("Policy initialised from scratch")
log_start_args(conf)
logging.info(f"New episodes per epoch: {conf['model']['batchsz']}")
logging.info(f"New episodes per epoch: {conf['model']['num_train_dialogues']}")
env, sess = env_config(conf, policy_sys)
......@@ -250,11 +250,11 @@ if __name__ == '__main__':
for i in range(conf['model']['epoch']):
idx = i + 1
# print("Epoch :{}".format(str(idx)))
update(env, policy_sys, conf['model']['batchsz'], idx, conf['model']['process_num'], seed=seed)
update(env, policy_sys, conf['model']['num_train_dialogues'], idx, conf['model']['process_num'], seed=seed)
if idx % conf['model']['eval_frequency'] == 0 and idx != 0:
time_now = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
logging.info(f"Evaluating after Dialogues: {idx * conf['model']['batchsz']} - {time_now}" + '-' * 60)
logging.info(f"Evaluating after Dialogues: {idx * conf['model']['num_train_dialogues']} - {time_now}" + '-' * 60)
eval_dict = eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path)
......@@ -264,7 +264,7 @@ if __name__ == '__main__':
eval_dict["avg_return"], save_path)
policy_sys.save(save_path, "last")
for key in eval_dict:
tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['batchsz'])
tb_writer.add_scalar(key, eval_dict[key], idx * conf['model']['num_train_dialogues'])
logging.info("End of Training: " +
time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
......
......@@ -3,7 +3,7 @@
"load_path": "/path/to/model/checkpoint",
"pretrained_load_path": "",
"use_pretrained_initialisation": false,
"batchsz": 1000,
"num_train_dialogues": 100,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
......
......@@ -3,7 +3,7 @@
"load_path": "convlab/policy/ppo/pretrained_models/mle",
"use_pretrained_initialisation": false,
"pretrained_load_path": "",
"batchsz": 1000,
"num_train_dialogues": 100,
"seed": 0,
"epoch": 50,
"eval_frequency": 5,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment