Commit e7da41a7 authored by Christian

DDPT model that optimizes stably, whether trained from scratch or pre-trained on the dataset. We can choose to use only a percentage of the data. Also fixed a bug where recommend actions were not output and a bug where the seed from the config file could be overwritten incorrectly.
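How the seed fix works, as a minimal sketch: with the old argparse default of 0, a seed value was always present, so the override list always replaced the seed from the config file; with a default of None, an omitted --seed leaves the config value untouched. The loader and merge plumbing below are illustrative stand-ins, not the actual ConvLab-3 helpers.

from argparse import ArgumentParser

parser = ArgumentParser()
# default=None makes an unset --seed detectable (the old default was 0)
parser.add_argument("--seed", type=int, default=None,
                    help="Seed for the policy parameter initialization")
cli = parser.parse_args([])  # simulate launching without --seed

config = {"model": {"seed": 42}}  # seed loaded from the JSON config file

# Mirror of the patched line: only build an override when --seed was given
args = [("model", "seed", cli.seed)] if cli.seed is not None else list()
for section, key, value in args:
    config[section][key] = value  # illustrative merge, not load_config_file

assert config["model"]["seed"] == 42  # config seed survives an unset --seed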
parent 85f9129f
@@ -43,6 +43,7 @@ def sampler(pid, queue, evt, sess, seed_range, goals):
         request = 0
         select = 0
         offer = 0
+        recommend = 0
         task_success = {}
         for i in range(40):
@@ -67,6 +68,8 @@ def sampler(pid, queue, evt, sess, seed_range, goals):
                         select += 1
                     if intent.lower() == 'offerbook':
                         offer += 1
+                    if intent.lower() == 'recommend':
+                        recommend += 1
             if session_over is True:
                 success = sess.evaluator.task_success()
@@ -81,7 +84,7 @@ def sampler(pid, queue, evt, sess, seed_range, goals):
             task_success[key].append(success_strict)
     buff.push(complete, success, success_strict, total_return_complete, total_return_success, turns, avg_actions / turns,
-              task_success, book, inform, request, select, offer)
+              task_success, book, inform, request, select, offer, recommend)
     # this is end of sampling all batchsz of items.
     # when sampling is over, push all buff data into queue
@@ -136,7 +139,7 @@ def evaluate_distributed(sess, seed_range, process_num, goals):
     return np.average(batch.complete), np.average(batch.success), np.average(batch.success_strict), \
         np.average(batch.total_return_success), np.average(batch.turns), np.average(batch.avg_actions), \
         batch.task_success, np.average(batch.book_actions), np.average(batch.inform_actions), np.average(batch.request_actions), \
-        np.average(batch.select_actions), np.average(batch.offer_actions)
+        np.average(batch.select_actions), np.average(batch.offer_actions), np.average(batch.recommend_actions)

 if __name__ == "__main__":
...
@@ -186,7 +186,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--path", type=str, default='convlab/policy/gdpl/semantic_level_config.json',
                         help="Load path for config file")
-    parser.add_argument("--seed", type=int, default=0,
+    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--pretrain", action='store_true', help="whether to pretrain the reward estimator")
     parser.add_argument("--mode", type=str, default='info',
@@ -202,7 +202,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
-    args = [('model', 'seed', seed)]
+    args = [('model', 'seed', seed)] if seed is not None else list()
     environment_config = load_config_file(path)
     save_config(vars(parser.parse_args()), environment_config, config_save_path)
...
@@ -184,7 +184,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--path", type=str, default='convlab/policy/pg/semantic_level_config.json',
                         help="Load path for config file")
-    parser.add_argument("--seed", type=int, default=0,
+    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
@@ -199,7 +199,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
-    args = [('model', 'seed', seed)]
+    args = [('model', 'seed', seed)] if seed is not None else list()
     environment_config = load_config_file(path)
     save_config(vars(parser.parse_args()), environment_config, config_save_path)
...
@@ -184,7 +184,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--path", type=str, default='convlab/policy/ppo/semantic_level_config.json',
                         help="Load path for config file")
-    parser.add_argument("--seed", type=int, default=0,
+    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
@@ -199,7 +199,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
-    args = [('model', 'seed', seed)]
+    args = [('model', 'seed', seed)] if seed is not None else list()
     environment_config = load_config_file(path)
     save_config(vars(parser.parse_args()), environment_config, config_save_path)
...
@@ -287,7 +287,7 @@ class Value(nn.Module):
 Transition_evaluator = namedtuple('Transition_evaluator',
                                   ('complete', 'success', 'success_strict', 'total_return_complete', 'total_return_success', 'turns',
                                    'avg_actions', 'task_success', 'book_actions', 'inform_actions', 'request_actions', 'select_actions',
-                                   'offer_actions'))
+                                   'offer_actions', 'recommend_actions'))

 class Memory_evaluator(object):
...
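The namedtuple change ripples through sampler() and evaluate_distributed() because the buffer pushes positional tuples and the consumer averages them field by field, so a new counter must be threaded through every push() call and every unpacking in the same position. A trimmed sketch of that contract (far fewer fields than the real Transition_evaluator, for brevity):

from collections import namedtuple

import numpy as np

# Trimmed stand-in for Transition_evaluator: producers and consumers must
# agree on the field order, so 'recommend_actions' is appended at the end.
Transition = namedtuple('Transition', ('success', 'offer_actions', 'recommend_actions'))

batch = Transition(success=[1, 1, 0],
                   offer_actions=[2, 0, 1],
                   recommend_actions=[1, 0, 2])

# Field-wise averaging, as in evaluate_distributed()
print(np.average(batch.recommend_actions))  # 1.0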
@@ -2,7 +2,7 @@
     "batchsz": 64,
     "epoch": 40,
     "gamma": 0.99,
-    "policy_lr": 0.00005,
+    "policy_lr": 5e-06,
     "supervised_lr": 1e-05,
     "entropy_weight": 0.01,
     "value_lr": 0.0001,
@@ -22,7 +22,7 @@
     "seed": 0,
     "lambda": 1,
     "tau": 0.001,
-    "policy_freq": 2,
+    "policy_freq": 1,
     "print_per_batch": 400,
     "c": 1.0,
     "rho_bar": 1,
@@ -31,7 +31,7 @@
     "dataset_name": "multiwoz21",
     "data_percentage": 1.0,
     "dialogue_order": 0,
-    "multiwoz_like": true,
+    "multiwoz_like": false,
     "regularization_weight": 0.0,
     "enc_input_dim": 128,
...
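The commit message says training can use only a percentage of the dataset; in the config this surfaces as data_percentage (the fraction kept) together with dialogue_order (which slice is kept). A hedged sketch of what such subsampling could look like; the helper below is hypothetical, and the actual ConvLab-3 loader may slice differently:

# Hypothetical subsampler: keep data_percentage of the dialogues, with
# dialogue_order selecting which (wrapped) slice is used.
def subsample_dialogues(dialogues, data_percentage=1.0, dialogue_order=0):
    n_keep = int(len(dialogues) * data_percentage)
    start = dialogue_order * n_keep
    indices = [(start + i) % len(dialogues) for i in range(n_keep)]
    return [dialogues[i] for i in indices]

data = [f"dialogue_{i}" for i in range(10)]
print(subsample_dialogues(data, data_percentage=0.5, dialogue_order=1))
# ['dialogue_5', 'dialogue_6', 'dialogue_7', 'dialogue_8', 'dialogue_9']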
@@ -226,7 +226,7 @@ if __name__ == '__main__':
         agent.imitating()
         logging.info(f"Epoch: {e}")
-        if e % args.eval_freq == 0 and e != 0:
+        if e % args.eval_freq == 0:
             precision, recall, f1 = agent.validate()
             logging.info(f"Precision: {precision}")
...
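Dropping the 'and e != 0' guard means validation now also runs at epoch 0, i.e. directly after the first imitating() pass, rather than only from epoch eval_freq onward. A toy comparison of the two schedules:

# Validation epochs for eval_freq = 5 over 12 epochs
eval_freq = 5
old_schedule = [e for e in range(12) if e % eval_freq == 0 and e != 0]  # [5, 10]
new_schedule = [e for e in range(12) if e % eval_freq == 0]             # [0, 5, 10]
print(old_schedule, new_schedule)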
@@ -103,7 +103,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--path", type=str, default='convlab/policy/vtrace_DPT/semantic_level_config.json',
                         help="Load path for config file")
-    parser.add_argument("--seed", type=int, default=0,
+    parser.add_argument("--seed", type=int, default=None,
                         help="Seed for the policy parameter initialization")
     parser.add_argument("--mode", type=str, default='info',
                         help="Set level for logger")
@@ -118,7 +118,7 @@ if __name__ == '__main__':
     logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
         init_logging(os.path.dirname(os.path.abspath(__file__)), mode)
-    args = [('model', 'seed', seed)]
+    args = [('model', 'seed', seed)] if seed is not None else list()
     environment_config = load_config_file(path)
...
@@ -171,20 +171,20 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do
     if conf['model']['process_num'] == 1:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
             avg_actions, task_success, book_acts, inform_acts, request_acts, \
-            select_acts, offer_acts = evaluate(sess,
+            select_acts, offer_acts, recommend_acts = evaluate(sess,
                                                num_dialogues=conf['model']['num_eval_dialogues'],
                                                sys_semantic_to_usr=conf['model'][
                                                    'sys_semantic_to_usr'],
                                                save_flag=save_eval, save_path=log_save_path, goals=goals)
-        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts
     else:
         complete_rate, success_rate, success_rate_strict, avg_return, turns, \
             avg_actions, task_success, book_acts, inform_acts, request_acts, \
-            select_acts, offer_acts = \
+            select_acts, offer_acts, recommend_acts = \
             evaluate_distributed(sess, list(range(1000, 1000 + conf['model']['num_eval_dialogues'])),
                                  conf['model']['process_num'], goals)
-        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts
+        total_acts = book_acts + inform_acts + request_acts + select_acts + offer_acts + recommend_acts

     task_success_gathered = {}
     for task_dict in task_success:
@@ -199,7 +199,7 @@ def eval_policy(conf, policy_sys, env, sess, save_eval, log_save_path, single_do
         f"Average Return: {avg_return}, Turns: {turns}, Average Actions: {avg_actions}, "
         f"Book Actions: {book_acts/total_acts}, Inform Actions: {inform_acts/total_acts}, "
         f"Request Actions: {request_acts/total_acts}, Select Actions: {select_acts/total_acts}, "
-        f"Offer Actions: {offer_acts/total_acts}")
+        f"Offer Actions: {offer_acts/total_acts}, Recommend Actions: {recommend_acts/total_acts}")

     for key in task_success:
         logging.info(
@@ -303,7 +303,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
     task_success = {'All_user_sim': [], 'All_evaluator': [], "All_evaluator_strict": [],
                     'total_return': [], 'turns': [], 'avg_actions': [],
                     'total_booking_acts': [], 'total_inform_acts': [], 'total_request_acts': [],
-                    'total_select_acts': [], 'total_offer_acts': []}
+                    'total_select_acts': [], 'total_offer_acts': [], 'total_recommend_acts': []}
     dial_count = 0
     for seed in range(1000, 1000 + num_dialogues):
         set_seed(seed)
@@ -319,6 +319,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         request = 0
         select = 0
         offer = 0
+        recommend = 0
         # this 40 represents the max turn of dialogue
         for i in range(40):
             sys_response, user_response, session_over, reward = sess.next_turn(
@@ -341,6 +342,8 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
                     select += 1
                 if intent.lower() == 'offerbook':
                     offer += 1
+                if intent.lower() == 'recommend':
+                    recommend += 1
             avg_actions += len(acts)
             turn_counter += 1
             turns += 1
@@ -377,6 +380,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         task_success['total_request_acts'].append(request)
         task_success['total_select_acts'].append(select)
         task_success['total_offer_acts'].append(offer)
+        task_success['total_recommend_acts'].append(recommend)

         # print(agent_sys.agent_saves)
         eval_save['Conversation {}'.format(str(dial_count))] = [
@@ -397,7 +402,7 @@ def evaluate(sess, num_dialogues=400, sys_semantic_to_usr=False, save_flag=False
         np.average(task_success['turns']), np.average(task_success['avg_actions']), task_success, \
         np.average(task_success['total_booking_acts']), np.average(task_success['total_inform_acts']), \
         np.average(task_success['total_request_acts']), np.average(task_success['total_select_acts']), \
-        np.average(task_success['total_offer_acts'])
+        np.average(task_success['total_offer_acts']), np.average(task_success['total_recommend_acts'])

 def model_downloader(download_dir, model_path):
...
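The recommend counter added throughout sampler() and evaluate() mirrors the existing book/inform/request/select/offer counters. For reference, an alternative sketch that collects all tracked intents in one dict, which would make the next intent a one-line addition; the (intent, domain, slot, value) act order is assumed from ConvLab's semantic act format:

# Alternative per-dialogue intent counting (sketch, not the repo's code):
# one dict of counters instead of one local variable per intent.
TRACKED_INTENTS = ('book', 'inform', 'request', 'select', 'offerbook', 'recommend')

def count_intents(dialogue_turns):
    counts = {intent: 0 for intent in TRACKED_INTENTS}
    for acts in dialogue_turns:  # one act list per system turn
        for intent, domain, slot, value in acts:
            if intent.lower() in counts:
                counts[intent.lower()] += 1
    return counts

turns = [[('recommend', 'hotel', 'name', 'Acorn'), ('inform', 'hotel', 'area', 'north')],
         [('offerbook', 'hotel', 'none', 'none')]]
print(count_intents(turns))
# {'book': 0, 'inform': 1, 'request': 0, 'select': 0, 'offerbook': 1, 'recommend': 1}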