From e368deeb3d405caf19236fb768360a6517a24fcd Mon Sep 17 00:00:00 2001
From: Carrey Wang <hrwang@se.cuhk.edu.hk>
Date: Thu, 22 Oct 2020 14:19:05 +0800
Subject: [PATCH] Add warmup for DQN and fix minor bugs (#150)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Initial commit

* Update README.md

* sync with commit aa1af0ee81ba591d1cf3c222c9d71963ed1dca98

* add gitignore

* update tutorial

* update mdrg, not use dbPointer

* update mdrg, not use dbPointer

* update mdrg, download before use dbPointer

* update analyzer

* update README

* update tutorial

* Fix dbquery when matching name

* move change to dev branch

* move dbquery change from master to dev branch

* disable travis for now

* do nothing in travis

* do nothing in travis for now

* not deploy now

* add docs

* update .travis.yml

* update .travis.yml

* update .travis.yml

* update rst files

* add alias center for centre in dbquery

* fix the policy training bug

* fix the bug of nan gradient

* add cross-lingual dst data

* Update README.md

* Update README.md

* Add CrossWoz Web support and some minor bug fix (#19)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>

* modify xdst data name

* Translation train on MultiWOZ (Chinese) nad CrossWOZ (English) of SUMBT (#17)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* updata xdst baseline

* Update README.md

* fix allennlp==0.9.0

* Update README.md

* modify build message function for goal generation

* Fix goal generator and dbquery for multiwoz (#32)

* move dbquery change from master to dev branch

* add alias center for centre in dbquery

* replace attraction type 'mutliple sports' to 'multiple sports', involving only one entity

* add depart and destination constraints for searching db (ignore=False), modify goal generator to draw the values of these two slots from database

* fix bug (#35)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* fix bug #34

* revert changes

* update demo video link

* Update README.md

* some changes in #36 (#37)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* fix bug #34

* revert changes

* revert changes

* some changes of #36

* fix analyzer example.py

* dst/evaluate.py: Use utf-8 encoding

* use transformers library to automate model caching

* Update README.md

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* fix nlu max len

* update travis

* Update run_agent.py

* Create README.md

* Update README.md

* modify human_eval README

* fix sclstm crosswoz import issues

* update travis.yml

* try to fix deploy

* Update README.md

* Update README.md

* Improve agenda policy (#52)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* Update README.md

#53

* Update README.md (#57)

* Improve agenda policy (#60)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* fix self.cur_domain=None when system offer book

* Improve agenda policy (#62)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* fix self.cur_domain=None when system offer book

* fix agenda for 0 choice

* fix sequicityy

* fix sequicityy

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* update setup.py:add tokenizers requirement

* fix typo

* update user nlg template

* Update README.md

* remove fail book in multiwoz goal generator

* fix taxi dontcare problem

* can manually set user goal in agenda now

* test goal overlap between generator and trainset

* change default taxi depart and destination from address to name/'the hotel/restaurant'

* change initiative from 4 to randint(2,4)

* agenda pop more da when only answer dontcare

* add 'the same area/pricerange/people/day' in agenda with 0.3 probability

* remove unnecessary thank you

* add domain for postcode and Phone in user templateNLG

* add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy

* add template for interent-no, parking-no in templatenlg

* update Evaluator: check whether final goal satisfies constraints

* update evaluator: check booked entity

* output goal analysis to file

* update goal analysis

* update

* Update analyzer.py

* Fix simulator (#83)

* remove fail book in multiwoz goal generator

* fix taxi dontcare problem

* can manually set user goal in agenda now

* test goal overlap between generator and trainset

* change default taxi depart and destination from address to name/'the hotel/restaurant'

* change initiative from 4 to randint(2,4)

* agenda pop more da when only answer dontcare

* add 'the same area/pricerange/people/day' in agenda with 0.3 probability

* remove unnecessary thank you

* add domain for postcode and Phone in user templateNLG

* add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy

* add template for interent-no, parking-no in templatenlg

* remove police and hospital domain in goal generator

* update multiwoz evaluator: adding 'internet/parking-none, 24:** to valid value

* fix nlg template (#88)

* add new_goal_model without police and hospital domain (#89)

* Normalize string comparisons in multiwoz template nlg to be case insensitive (#87)

* normalize template nlg keys to be lower case

* fix slot comparison in multiwoz nlg to be case insensitive

* use value_lower instead of calling .lower() on each comparison

* Add police n hospital (#95)

* add back police and hospital goal

* update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result

* update dbquery and session (#99)

* update dbquery: ? matches all; fix bug in init_session

* update multiwoz_eval, check Ref of booked

* filter domain in final_goal_analyze

Co-authored-by: newRuntieException <wdz15@mails.tsinghua.edu.cn>

* Add dockerfile (#98)

* fix nlg template

* add dockerfile

* include missing packages at setup.py (#102)

* multiwoz dbquery doesnt require mutable constraints (#106)

* Add police n hospital (#107)

* add back police and hospital goal

* update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result

* update user templatenlg

* add test set example for dstc9 (multiwoz_zh, crosswoz_en) (#108)

* Add dockerfile (#110)

* fix nlg template

* add dockerfile

* add package for dockerfile

* update versions

* Update README.md

* Update versions in setup (#111)

* move dbquery change from master to dev branch

* add alias center for centre in dbquery

* fix sequicityy

* update versions

Co-authored-by: zqwerty <zhuq96@hotmail.com>
Co-authored-by: zhuqi <zqwerty@users.noreply.github.com>

* Update README.md

* Update README.md

* Update README.md

* fix system nlg template bug (#117)

* add 'book' in DST evaluation. (#85)

* Maintenance (#119)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* dstc9 eval

* dstc9 xldst evaluation

* Nlg template fix (#121)

* fix nlg template

* fix user nlg template issue

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* dstc9 xldst evaluation (#122)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* revise xldst evaluation (#124)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* Update dst.py

* update precision, recall, f1 calculation

* minor change

* fix policy evaluation

* Nlg template fix (#127)

* fix nlg template

* fix user nlg template issue

* fix system NLG template

* nlu update and bugfix (#118)

* jointBERT_new avaliable && fix milu dataset_reader && fix jointBERT/tag2id

* remove jointBERT_new

* update milu/multiwoz/nlu.py model_file path

* add metrics in XLDST evaluation (#126)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* add input reqt vals in human eval (#128)

* Maintenance (#129)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* Human (#131)

* change task config

* add final goal logging

* encapsule PipelineAgent internal state interface for return and
replacement

* Maintenance (#132)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix a database typo

* Maintenance (#134)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix goal generator for police domain message

* fix a minor typo in crosswoz database (#133)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* remove low performance baselines (#136)

* Human2 (#137)

* change task config

* add final goal logging

* encapsule PipelineAgent internal state interface for return and
replacement

* fix bug associted with the issue of strange user input

* Fix a bug in TRADE CrossWOZ training (#138)

* add 'book' in DST evaluation.

* Fix TRADE crosswoz training evaluation bug

Co-authored-by: zheng <zheng@zhangzheng-PC.lan>

* Maintenance (#140)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix goal generator for police domain message

* update template NLG

* Add note for deploy web service (#139)

* add 'book' in DST evaluation.

* Fix TRADE crosswoz training evaluation bug

* Add note on deploy

Co-authored-by: zheng <zheng@zhangzheng-PC.lan>

* add value unification

* fix XLDST evaluation (#141)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification

* fix user Nlg template (#142)

* fix system nlg template bug

* fix user nlg issue

* fix white character issue #144

* deal with white charater in XLDST evaluation (#145)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification

* fix white character issue #144

* DQN (#113)

* implemented script to extract all the statistics for all dialogue_act in data

* changed script for actions be compatible to sys_da_voc.txt actions

* multiwoz vector now supports composite actions

* implemented ReplayMemory and EpsilongGreedyPolicy

* implemented a basic version of dqn

* included some comments

* Add DQN Test and Change file structure  (#146)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

* Update server.py

* add test for DQN

* change server

Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>

* update eval

* dump dst eval results

* make value lower

* add progress bar

* fix bug in last commit

* Update policy_agenda_multiwoz.py

* remove unnecessary mapping (#147)

* udpate dstc9 eval

* make value lower

* add warm up for dqn and fix bugs

* rm unrelated files

Co-authored-by: zhuqi <zqwerty@users.noreply.github.com>
Co-authored-by: zqwerty <zhuq96@hotmail.com>
Co-authored-by: Ryuichi Takanobu <truthless11@gmail.com>
Co-authored-by: newRuntieException <wdz15@mails.tsinghua.edu.cn>
Co-authored-by: liangrz <liangrz15@mails.tsinghua.edu.cn>
Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: 罗崚骁 <function2@qq.com>
Co-authored-by: mehrad <mehrad@stanford.edu>
Co-authored-by: pengbaolin <39398162+pengbaolin@users.noreply.github.com>
Co-authored-by: Jinchao Li <38700695+jincli@users.noreply.github.com>
Co-authored-by: Shahin Shayandeh <shahins@microsoft.com>
Co-authored-by: aaa123git <43716234+aaa123git@users.noreply.github.com>
Co-authored-by: Bruno Eidi Nishimoto <bruno_nishimoto@hotmail.com>
Co-authored-by: Vojtěch Hudeček <vojta.hudecek@gmail.com>
Co-authored-by: zhangzthu <zhangz.goal@gmail.com>
Co-authored-by: xw <48146603+xwwwwww@users.noreply.github.com>
Co-authored-by: zheng <zheng@zhangzheng-PC.lan>
Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>
---
 convlab2/policy/dqn/dqn.py                    | 33 ++++++--
 convlab2/policy/dqn/train.py                  | 78 ++++++++++++++++++-
 .../policy/hdsa/multiwoz/transformer/Beam.py  |  2 +-
 3 files changed, 103 insertions(+), 10 deletions(-)

diff --git a/convlab2/policy/dqn/dqn.py b/convlab2/policy/dqn/dqn.py
index 39c04ac..3a02929 100644
--- a/convlab2/policy/dqn/dqn.py
+++ b/convlab2/policy/dqn/dqn.py
@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
 from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
 from convlab2.util.train_util import init_logging_handler
 from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
 from convlab2.util.file_util import cached_path
 import zipfile
 import sys
@@ -32,6 +33,8 @@ class DQN(Policy):
         self.training_iter = cfg['training_iter']
         self.training_batch_iter = cfg['training_batch_iter']
         self.batch_size = cfg['batch_size']
+        self.epsilon = cfg['epsilon_spec']['start']
+        self.rule_bot = RuleBasedMultiwozBot()
         self.gamma = cfg['gamma']
         self.is_train = is_train
         if is_train:
@@ -58,9 +61,10 @@ class DQN(Policy):
         self.loss_fn = nn.MSELoss()
 
     def update_memory(self, sample):
+        self.memory.reset()
         self.memory.append(sample)
         
-    def predict(self, state):
+    def predict(self, state, warm_up=False):
         """
         Predict an system action given state.
         Args:
@@ -68,12 +72,27 @@ class DQN(Policy):
         Returns:
             action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
         """
-        s_vec = torch.Tensor(self.vector.state_vectorize(state))
-        a = self.net.select_action(s_vec.to(device=DEVICE))
-
-        action = self.vector.action_devectorize(a.numpy())
-
-        state['system_action'] = action
+        if warm_up:
+            action = self.rule_action(state)
+            state['system_action'] = action
+        else:
+            s_vec = torch.Tensor(self.vector.state_vectorize(state))
+            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
+            action = self.vector.action_devectorize(a.numpy())
+            state['system_action'] = action
+        return action
+    
+    def rule_action(self, state):
+        if self.epsilon > np.random.rand():
+            a = torch.randint(self.vector.da_dim, (1, ))
+            # transforms action index to a vector action (one-hot encoding)
+            a_vec = torch.zeros(self.vector.da_dim)
+            a_vec[a] = 1.
+            action = self.vector.action_devectorize(a_vec.numpy())
+        else:
+            # rule-based warm up
+            action = self.rule_bot.predict(state)
+        
         return action
 
     def init_session(self):
diff --git a/convlab2/policy/dqn/train.py b/convlab2/policy/dqn/train.py
index 8ebcf28..2c6412f 100755
--- a/convlab2/policy/dqn/train.py
+++ b/convlab2/policy/dqn/train.py
@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
     queue.put([pid, buff])
     evt.wait()
 
+def warmupsampler(pid, queue, evt, env, policy, batchsz):
+    """
+    This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
+    processes.
+    :param pid: process id
+    :param queue: multiprocessing.Queue, to collect sampled data
+    :param evt: multiprocessing.Event, to keep the process alive
+    :param env: environment instance
+    :param policy: policy network, to generate action from current policy
+    :param batchsz: total sampled items
+    :return:
+    """
+    buff = Memory()
+
+    # we need to sample batchsz of (state, action, next_state, reward, mask)
+    # each trajectory contains `trajectory_len` num of items, so we only need to sample
+    # `batchsz//trajectory_len` num of trajectory totally
+    # the final sampled number may be larger than batchsz.
+
+    sampled_num = 0
+    sampled_traj_num = 0
+    traj_len = 50
+    real_traj_len = 0
+
+    while sampled_num < batchsz:
+        # for each trajectory, we reset the env and get initial state
+        s = env.reset()
+
+        for t in range(traj_len):
+
+            # [s_dim] => [a_dim]
+            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
+            a = policy.predict(s, warm_up=True)
+
+            # interact with env
+            next_s, r, done = env.step(a)
+
+            # a flag indicates ending or not
+            mask = 0 if done else 1
+
+            # get reward compared to demostrations
+            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
+
+            # save to queue
+            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
+
+            # update per step
+            s = next_s
+            real_traj_len = t
+
+            if done:
+                break
 
-def sample(env, policy, batchsz, process_num):
+        # this is end of one trajectory
+        sampled_num += real_traj_len
+        sampled_traj_num += 1
+        # t indicates the valid trajectory length
+
+    # this is end of sampling all batchsz of items.
+    # when sampling is over, push all buff data into queue
+    queue.put([pid, buff])
+    evt.wait()
+
+
+def sample(env, policy, batchsz, process_num, warm_up=False):
     """
     Given batchsz number of task, the batchsz will be splited equally to each processes
     and when processes return, it merge all data and return
@@ -119,7 +182,10 @@ def sample(env, policy, batchsz, process_num):
     processes = []
     for i in range(process_num):
         process_args = (i, queue, evt, env, policy, process_batchsz)
-        processes.append(mp.Process(target=sampler, args=process_args))
+        if warm_up:
+            processes.append(mp.Process(target=warmupsampler, args=process_args))
+        else:
+            processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
         p.daemon = True
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
     policy.update(epoch)
 
 
+def warm_start(env, policy, batchsz, epoch, process_num):
+    # sample data asynchronously
+    buff = sample(env, policy, batchsz, process_num, warm_up=True)
+    policy.update_memory(buff)
+    policy.update(epoch)
+
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ if __name__ == '__main__':
     evaluator = MultiWozEvaluator()
     env = Environment(None, simulator, None, dst_sys, evaluator)
 
+    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
 
     for i in range(args.epoch):
         update(env, policy_sys, args.batchsz, i, args.process_num)
diff --git a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
index 0d9e520..469d8ff 100755
--- a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
+++ b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
@@ -66,7 +66,7 @@ class Beam(object):
 
         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
-        prev_k = best_scores_id / num_words
+        prev_k = best_scores_id // num_words
         self.prev_ks.append(prev_k)
         self.next_ys.append(best_scores_id - prev_k * num_words)
 
-- 
GitLab