From e368deeb3d405caf19236fb768360a6517a24fcd Mon Sep 17 00:00:00 2001 From: Carrey Wang <hrwang@se.cuhk.edu.hk> Date: Thu, 22 Oct 2020 14:19:05 +0800 Subject: [PATCH] Add warmup for DQN and fix minor bugs (#150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Initial commit * Update README.md * sync with commit aa1af0ee81ba591d1cf3c222c9d71963ed1dca98 * add gitignore * update tutorial * update mdrg, not use dbPointer * update mdrg, not use dbPointer * update mdrg, download before use dbPointer * update analyzer * update README * update tutorial * Fix dbquery when matching name * move change to dev branch * move dbquery change from master to dev branch * disable travis for now * do nothing in travis * do nothing in travis for now * not deploy now * add docs * update .travis.yml * update .travis.yml * update .travis.yml * update rst files * add alias center for centre in dbquery * fix the policy training bug * fix the bug of nan gradient * add cross-lingual dst data * Update README.md * Update README.md * Add CrossWoz Web support and some minor bug fix (#19) * Initial commit * first commit * add build * add build * add build * add recommend * add crosswoz config in deploy * add crosswoz at html * debug chinese vision * fix system bug according to convlab2 * master change * modify .gitignore * delete svm_camrest_usr.pickle Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local> Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local> * modify xdst data name * Translation train on MultiWOZ (Chinese) nad CrossWOZ (English) of SUMBT (#17) * multiwoz_zh * crosswoz_en * translation train * test translation train * update evaluation code * update evaluation code for crosswoz * evaluate human val set * update readme * evaluate machine val * extract all ontology, bad result * update evalutate * update evalutation result on crosswoz-en * updata xdst baseline * Update README.md * fix allennlp==0.9.0 * 
Update README.md * modify build message function for goal generation * Fix goal generator and dbquery for multiwoz (#32) * move dbquery change from master to dev branch * add alias center for centre in dbquery * replace attraction type 'mutliple sports' to 'multiple sports', involving only one entity * add depart and destination constraints for searching db (ignore=False), modify goal generator to draw the values of these two slots from database * fix bug (#35) * multiwoz_zh * crosswoz_en * translation train * test translation train * update evaluation code * update evaluation code for crosswoz * evaluate human val set * update readme * evaluate machine val * extract all ontology, bad result * update evalutate * update evalutation result on crosswoz-en * fix bug #34 * revert changes * update demo video link * Update README.md * some changes in #36 (#37) * multiwoz_zh * crosswoz_en * translation train * test translation train * update evaluation code * update evaluation code for crosswoz * evaluate human val set * update readme * evaluate machine val * extract all ontology, bad result * update evalutate * update evalutation result on crosswoz-en * fix bug #34 * revert changes * revert changes * some changes of #36 * fix analyzer example.py * dst/evaluate.py: Use utf-8 encoding * use transformers library to automate model caching * Update README.md * cut sentences that exceed 512 tokens in jointBERT * Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy. 
* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act * fix nlu max len * update travis * Update run_agent.py * Create README.md * Update README.md * modify human_eval README * fix sclstm crosswoz import issues * update travis.yml * try to fix deploy * Update README.md * Update README.md * Improve agenda policy (#52) * cut sentences that exceed 512 tokens in jointBERT * Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy. * improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act * improve goal sample strategy * Update README.md #53 * Update README.md (#57) * Improve agenda policy (#60) * cut sentences that exceed 512 tokens in jointBERT * Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy. * improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act * improve goal sample strategy * fix self.cur_domain=None when system offer book * Improve agenda policy (#62) * cut sentences that exceed 512 tokens in jointBERT * Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy. 
* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act * improve goal sample strategy * fix self.cur_domain=None when system offer book * fix agenda for 0 choice * fix sequicityy * fix sequicityy * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * update setup.py:add tokenizers requirement * fix typo * update user nlg template * Update README.md * remove fail book in multiwoz goal generator * fix taxi dontcare problem * can manually set user goal in agenda now * test goal overlap between generator and trainset * change default taxi depart and destination from address to name/'the hotel/restaurant' * change initiative from 4 to randint(2,4) * agenda pop more da when only answer dontcare * add 'the same area/pricerange/people/day' in agenda with 0.3 probability * remove unnecessary thank you * add domain for postcode and Phone in user templateNLG * add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy * add template for interent-no, parking-no in templatenlg * update Evaluator: check whether final goal satisfies constraints * update evaluator: check booked entity * output goal analysis to file * update goal analysis * update * Update analyzer.py * Fix simulator (#83) * remove fail book in multiwoz goal generator * fix taxi dontcare problem * can manually set user goal in agenda now * test goal overlap between generator and trainset * change default taxi depart and destination from address to name/'the hotel/restaurant' * change initiative from 4 to randint(2,4) * agenda pop more da when only answer dontcare * add 'the same area/pricerange/people/day' in agenda with 0.3 probability * remove unnecessary thank you * add domain for postcode and Phone in user templateNLG * add 
**kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy * add template for interent-no, parking-no in templatenlg * remove police and hospital domain in goal generator * update multiwoz evaluator: adding 'internet/parking-none, 24:** to valid value * fix nlg template (#88) * add new_goal_model without police and hospital domain (#89) * Normalize string comparisons in multiwoz template nlg to be case insensitive (#87) * normalize template nlg keys to be lower case * fix slot comparison in multiwoz nlg to be case insensitive * use value_lower instead of calling .lower() on each comparison * Add police n hospital (#95) * add back police and hospital goal * update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result * update dbquery and session (#99) * update dbquery: ? matches all; fix bug in init_session * update multiwoz_eval, check Ref of booked * filter domain in final_goal_analyze Co-authored-by: newRuntieException <wdz15@mails.tsinghua.edu.cn> * Add dockerfile (#98) * fix nlg template * add dockerfile * include missing packages at setup.py (#102) * multiwoz dbquery doesnt require mutable constraints (#106) * Add police n hospital (#107) * add back police and hospital goal * update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result * update user templatenlg * add test set example for dstc9 (multiwoz_zh, crosswoz_en) (#108) * Add dockerfile (#110) * fix nlg template * add dockerfile * add package for dockerfile * update versions * Update README.md * Update versions in setup (#111) * move dbquery change from master to dev branch * add alias center for centre in dbquery * fix sequicityy * update versions Co-authored-by: zqwerty <zhuq96@hotmail.com> Co-authored-by: zhuqi <zqwerty@users.noreply.github.com> * Update README.md * Update README.md * Update 
README.md * fix system nlg template bug (#117) * add 'book' in DST evaluation. (#85) * Maintenance (#119) * add test set example for dstc9 (multiwoz_zh, crosswoz_en) * update new_goal_model.pkl * update crosswoz auto_sys_template_nlg * add postcode as special case for NLU tokenization * dstc9 eval * dstc9 xldst evaluation * Nlg template fix (#121) * fix nlg template * fix user nlg template issue * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * dstc9 xldst evaluation (#122) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example * revise xldst evaluation (#124) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example * Update dst.py * update precision, recall, f1 calculation * minor change * fix policy evaluation * Nlg template fix (#127) * fix nlg template * fix user nlg template issue * fix system NLG template * nlu update and bugfix (#118) * jointBERT_new avaliable && fix milu dataset_reader && fix jointBERT/tag2id * remove jointBERT_new * update milu/multiwoz/nlu.py model_file path * add metrics in XLDST evaluation (#126) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example * update precision, recall, f1 calculation * minor change * add input reqt vals in human eval (#128) * Maintenance 
(#129) * add test set example for dstc9 (multiwoz_zh, crosswoz_en) * update new_goal_model.pkl * update crosswoz auto_sys_template_nlg * add postcode as special case for NLU tokenization * fix lower case for int value in nlg.py * Human (#131) * change task config * add final goal logging * encapsule PipelineAgent internal state interface for return and replacement * Maintenance (#132) * add test set example for dstc9 (multiwoz_zh, crosswoz_en) * update new_goal_model.pkl * update crosswoz auto_sys_template_nlg * add postcode as special case for NLU tokenization * fix lower case for int value in nlg.py * fix empty user utterance problem in multiwoz simulator, issue #130 * remove debug output * fix a database typo * Maintenance (#134) * add test set example for dstc9 (multiwoz_zh, crosswoz_en) * update new_goal_model.pkl * update crosswoz auto_sys_template_nlg * add postcode as special case for NLU tokenization * fix lower case for int value in nlg.py * fix empty user utterance problem in multiwoz simulator, issue #130 * remove debug output * fix goal generator for police domain message * fix a minor typo in crosswoz database (#133) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example * update precision, recall, f1 calculation * minor change * fix a database typo * use selectedResults for missing name * remove low performance baselines (#136) * Human2 (#137) * change task config * add final goal logging * encapsule PipelineAgent internal state interface for return and replacement * fix bug associted with the issue of strange user input * Fix a bug in TRADE CrossWOZ training (#138) * add 'book' in DST evaluation. 
* Fix TRADE crosswoz training evaluation bug Co-authored-by: zheng <zheng@zhangzheng-PC.lan> * Maintenance (#140) * add test set example for dstc9 (multiwoz_zh, crosswoz_en) * update new_goal_model.pkl * update crosswoz auto_sys_template_nlg * add postcode as special case for NLU tokenization * fix lower case for int value in nlg.py * fix empty user utterance problem in multiwoz simulator, issue #130 * remove debug output * fix goal generator for police domain message * update template NLG * Add note for deploy web service (#139) * add 'book' in DST evaluation. * Fix TRADE crosswoz training evaluation bug * Add note on deploy Co-authored-by: zheng <zheng@zhangzheng-PC.lan> * add value unification * fix XLDST evaluation (#141) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example * update precision, recall, f1 calculation * minor change * fix a database typo * use selectedResults for missing name * add value unification * fix user Nlg template (#142) * fix system nlg template bug * fix user nlg issue * fix white character issue #144 * deal with white charater in XLDST evaluation (#145) * update sumbt translation train result with evaluation mode set * update extract values * automatically download sumbt model * dstc9 eval * dstc9 xldst evaluation * modify example * add .gitignore * remove precision, recall, f1 * release 250 test data * revise evaluation * fix file submission example * update precision, recall, f1 calculation * minor change * fix a database typo * use selectedResults for missing name * add value unification * fix white character issue #144 * DQN (#113) * implemented script to extract all the statistics for all dialogue_act in data * changed script for actions be compatible to sys_da_voc.txt actions * 
multiwoz vector now supports composite actions * implemented ReplayMemory and EpsilongGreedyPolicy * implemented a basic version of dqn * included some comments * Add DQN Test and Change file structure (#146) * Initial commit * first commit * add build * add build * add build * add recommend * add crosswoz config in deploy * add crosswoz at html * debug chinese vision * fix system bug according to convlab2 * master change * modify .gitignore * delete svm_camrest_usr.pickle * Update server.py * add test for DQN * change server Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn> Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local> Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local> Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com> Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk> * update eval * dump dst eval results * make value lower * add progress bar * fix bug in last commit * Update policy_agenda_multiwoz.py * remove unnecessary mapping (#147) * udpate dstc9 eval * make value lower * add warm up for dqn and fix bugs * rm unrelated files Co-authored-by: zhuqi <zqwerty@users.noreply.github.com> Co-authored-by: zqwerty <zhuq96@hotmail.com> Co-authored-by: Ryuichi Takanobu <truthless11@gmail.com> Co-authored-by: newRuntieException <wdz15@mails.tsinghua.edu.cn> Co-authored-by: liangrz <liangrz15@mails.tsinghua.edu.cn> Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn> Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local> Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local> Co-authored-by: 罗崚骁 <function2@qq.com> Co-authored-by: mehrad <mehrad@stanford.edu> Co-authored-by: pengbaolin <39398162+pengbaolin@users.noreply.github.com> Co-authored-by: Jinchao Li <38700695+jincli@users.noreply.github.com> Co-authored-by: Shahin Shayandeh <shahins@microsoft.com> Co-authored-by: aaa123git <43716234+aaa123git@users.noreply.github.com> Co-authored-by: Bruno Eidi Nishimoto 
<bruno_nishimoto@hotmail.com>
Co-authored-by: Vojtěch Hudeček <vojta.hudecek@gmail.com>
Co-authored-by: zhangzthu <zhangz.goal@gmail.com>
Co-authored-by: xw <48146603+xwwwwww@users.noreply.github.com>
Co-authored-by: zheng <zheng@zhangzheng-PC.lan>
Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>
---
 convlab2/policy/dqn/dqn.py                    | 33 ++++++--
 convlab2/policy/dqn/train.py                  | 78 ++++++++++++++++++-
 .../policy/hdsa/multiwoz/transformer/Beam.py  |  2 +-
 3 files changed, 103 insertions(+), 10 deletions(-)

diff --git a/convlab2/policy/dqn/dqn.py b/convlab2/policy/dqn/dqn.py
index 39c04ac..3a02929 100644
--- a/convlab2/policy/dqn/dqn.py
+++ b/convlab2/policy/dqn/dqn.py
@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
 from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
 from convlab2.util.train_util import init_logging_handler
 from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
 from convlab2.util.file_util import cached_path
 import zipfile
 import sys
@@ -32,6 +33,8 @@ class DQN(Policy):
         self.training_iter = cfg['training_iter']
         self.training_batch_iter = cfg['training_batch_iter']
         self.batch_size = cfg['batch_size']
+        self.epsilon = cfg['epsilon_spec']['start']
+        self.rule_bot = RuleBasedMultiwozBot()
         self.gamma = cfg['gamma']
         self.is_train = is_train
         if is_train:
@@ -58,9 +61,10 @@ class DQN(Policy):
         self.loss_fn = nn.MSELoss()
 
     def update_memory(self, sample):
+        self.memory.reset()
         self.memory.append(sample)
 
-    def predict(self, state):
+    def predict(self, state, warm_up=False):
         """
         Predict an system action given state.
         Args:
@@ -68,12 +72,27 @@ class DQN(Policy):
         Returns:
             action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
         """
-        s_vec = torch.Tensor(self.vector.state_vectorize(state))
-        a = self.net.select_action(s_vec.to(device=DEVICE))
-
-        action = self.vector.action_devectorize(a.numpy())
-
-        state['system_action'] = action
+        if warm_up:
+            action = self.rule_action(state)
+            state['system_action'] = action
+        else:
+            s_vec = torch.Tensor(self.vector.state_vectorize(state))
+            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
+            action = self.vector.action_devectorize(a.numpy())
+            state['system_action'] = action
         return action
+
+    def rule_action(self, state):
+        if self.epsilon > np.random.rand():
+            a = torch.randint(self.vector.da_dim, (1, ))
+            # transforms action index to a vector action (one-hot encoding)
+            a_vec = torch.zeros(self.vector.da_dim)
+            a_vec[a] = 1.
+            action = self.vector.action_devectorize(a_vec.numpy())
+        else:
+            # rule-based warm up
+            action = self.rule_bot.predict(state)
+        return action
 
     def init_session(self):
diff --git a/convlab2/policy/dqn/train.py b/convlab2/policy/dqn/train.py
index 8ebcf28..2c6412f 100755
--- a/convlab2/policy/dqn/train.py
+++ b/convlab2/policy/dqn/train.py
@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
     queue.put([pid, buff])
     evt.wait()
 
+def warmupsampler(pid, queue, evt, env, policy, batchsz):
+    """
+    This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
+    processes.
+    :param pid: process id
+    :param queue: multiprocessing.Queue, to collect sampled data
+    :param evt: multiprocessing.Event, to keep the process alive
+    :param env: environment instance
+    :param policy: policy network, to generate action from current policy
+    :param batchsz: total sampled items
+    :return:
+    """
+    buff = Memory()
+
+    # we need to sample batchsz of (state, action, next_state, reward, mask)
+    # each trajectory contains `trajectory_len` num of items, so we only need to sample
+    # `batchsz//trajectory_len` num of trajectory totally
+    # the final sampled number may be larger than batchsz.
+
+    sampled_num = 0
+    sampled_traj_num = 0
+    traj_len = 50
+    real_traj_len = 0
+
+    while sampled_num < batchsz:
+        # for each trajectory, we reset the env and get initial state
+        s = env.reset()
+
+        for t in range(traj_len):
+
+            # [s_dim] => [a_dim]
+            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
+            a = policy.predict(s, warm_up=True)
+
+            # interact with env
+            next_s, r, done = env.step(a)
+
+            # a flag indicates ending or not
+            mask = 0 if done else 1
+
+            # get reward compared to demostrations
+            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
+
+            # save to queue
+            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
+
+            # update per step
+            s = next_s
+            real_traj_len = t
+
+            if done:
+                break
 
-def sample(env, policy, batchsz, process_num):
+        # this is end of one trajectory
+        sampled_num += real_traj_len
+        sampled_traj_num += 1
+        # t indicates the valid trajectory length
+
+    # this is end of sampling all batchsz of items.
+    # when sampling is over, push all buff data into queue
+    queue.put([pid, buff])
+    evt.wait()
+
+
+def sample(env, policy, batchsz, process_num, warm_up=False):
     """
     Given batchsz number of task, the batchsz will be splited equally to each processes
     and when processes return, it merge all data and return
@@ -119,7 +182,10 @@ def sample(env, policy, batchsz, process_num):
     processes = []
     for i in range(process_num):
         process_args = (i, queue, evt, env, policy, process_batchsz)
-        processes.append(mp.Process(target=sampler, args=process_args))
+        if warm_up:
+            processes.append(mp.Process(target=warmupsampler, args=process_args))
+        else:
+            processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
         p.daemon = True
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
     policy.update(epoch)
 
 
+def warm_start(env, policy, batchsz, epoch, process_num):
+    # sample data asynchronously
+    buff = sample(env, policy, batchsz, process_num, warm_up=True)
+    policy.update_memory(buff)
+    policy.update(epoch)
+
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ if __name__ == '__main__':
     evaluator = MultiWozEvaluator()
     env = Environment(None, simulator, None, dst_sys, evaluator)
 
+    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
     for i in range(args.epoch):
         update(env, policy_sys, args.batchsz, i, args.process_num)
diff --git a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
index 0d9e520..469d8ff 100755
--- a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
+++ b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
@@ -66,7 +66,7 @@ class Beam(object):
 
         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
-        prev_k = best_scores_id / num_words
+        prev_k = best_scores_id // num_words
         self.prev_ks.append(prev_k)
         self.next_ys.append(best_scores_id - prev_k * num_words)
 
-- 
GitLab
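
The Beam.py hunk replaces `/` with `//` when recovering the beam index from an index into a flattened (beam x word) array. The sketch below illustrates that index arithmetic in plain Python. It is an illustration only, not code from the repository: the helper name `split_flat_index` and the sample numbers are hypothetical, and it works on plain integers rather than tensors. In Python 3, `/` on integers is true division and returns a float, so floor division is needed for the result to be usable as an index.

```python
from typing import Tuple


def split_flat_index(flat_id: int, num_words: int) -> Tuple[int, int]:
    """Recover (beam, word) from an index into a flattened (beam x word) array.

    Hypothetical helper for illustration; it mirrors the arithmetic in the hunk.
    """
    prev_k = flat_id // num_words            # which beam the score came from
    word_id = flat_id - prev_k * num_words   # which word within that beam
    return prev_k, word_id


if __name__ == "__main__":
    num_words = 5
    print(split_flat_index(13, num_words))   # (2, 3)
    print(13 / num_words)                    # 2.6, not a valid beam index
```

The same pair could be obtained with `divmod(flat_id, num_words)`; the subtraction form is kept here because it matches the lines in the diff.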