diff --git a/convlab2/policy/dqn/dqn.py b/convlab2/policy/dqn/dqn.py
index 39c04ac13c6e3e1cee03b591c058b86e15e0de47..3a029290ca9cb2591c2ee5ae44e832568c5eb623 100644
--- a/convlab2/policy/dqn/dqn.py
+++ b/convlab2/policy/dqn/dqn.py
@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
 from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
 from convlab2.util.train_util import init_logging_handler
 from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
 from convlab2.util.file_util import cached_path
 import zipfile
 import sys
@@ -32,6 +33,8 @@ class DQN(Policy):
         self.training_iter = cfg['training_iter']
         self.training_batch_iter = cfg['training_batch_iter']
         self.batch_size = cfg['batch_size']
+        self.epsilon = cfg['epsilon_spec']['start']
+        self.rule_bot = RuleBasedMultiwozBot()
         self.gamma = cfg['gamma']
         self.is_train = is_train
         if is_train:
@@ -58,9 +61,10 @@ class DQN(Policy):
         self.loss_fn = nn.MSELoss()
 
     def update_memory(self, sample):
+        self.memory.reset()
         self.memory.append(sample)
 
-    def predict(self, state):
+    def predict(self, state, warm_up=False):
         """
         Predict a system action given the state.
         Args:
@@ -68,12 +72,27 @@ class DQN(Policy):
         Returns:
             action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
         """
-        s_vec = torch.Tensor(self.vector.state_vectorize(state))
-        a = self.net.select_action(s_vec.to(device=DEVICE))
-
-        action = self.vector.action_devectorize(a.numpy())
-
-        state['system_action'] = action
+        if warm_up:
+            action = self.rule_action(state)
+            state['system_action'] = action
+        else:
+            s_vec = torch.Tensor(self.vector.state_vectorize(state))
+            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
+            action = self.vector.action_devectorize(a.numpy())
+            state['system_action'] = action
+        return action
+
+    def rule_action(self, state):
+        if self.epsilon > np.random.rand():
+            a = torch.randint(self.vector.da_dim, (1, ))
+            # transform the action index into a one-hot action vector
+            a_vec = torch.zeros(self.vector.da_dim)
+            a_vec[a] = 1.
+            action = self.vector.action_devectorize(a_vec.numpy())
+        else:
+            # rule-based warm-up action
+            action = self.rule_bot.predict(state)
+        return action
 
     def init_session(self):
diff --git a/convlab2/policy/dqn/train.py b/convlab2/policy/dqn/train.py
index 8ebcf280a73e20bc289a8860713fea78c878b312..2c6412f8e5a3f5d65e38a1fb6c53e5058e94ab83 100755
--- a/convlab2/policy/dqn/train.py
+++ b/convlab2/policy/dqn/train.py
@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
     queue.put([pid, buff])
     evt.wait()
 
+def warmupsampler(pid, queue, evt, env, policy, batchsz):
+    """
+    This is a sampler function; it is called by multiprocessing.Process to sample warm-up data from the
+    environment in multiple processes.
+    :param pid: process id
+    :param queue: multiprocessing.Queue, to collect sampled data
+    :param evt: multiprocessing.Event, to keep the process alive
+    :param env: environment instance
+    :param policy: policy network, to generate action from current policy
+    :param batchsz: total sampled items
+    :return:
+    """
+    buff = Memory()
+
+    # we need to sample batchsz items of (state, action, next_state, reward, mask)
+    # each trajectory contains at most `traj_len` items, so we only need to sample
+    # about `batchsz // traj_len` trajectories in total
+    # the final sampled number may be larger than batchsz.
+
+    sampled_num = 0
+    sampled_traj_num = 0
+    traj_len = 50
+    real_traj_len = 0
+
+    while sampled_num < batchsz:
+        # for each trajectory, we reset the env and get the initial state
+        s = env.reset()
+
+        for t in range(traj_len):
+
+            # [s_dim] => [a_dim]
+            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
+            a = policy.predict(s, warm_up=True)
+
+            # interact with env
+            next_s, r, done = env.step(a)
+
+            # a flag indicating whether the dialogue has ended
+            mask = 0 if done else 1
+
+            # vectorize the next state
+            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
+
+            # save to the local buffer
+            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
+
+            # update per step
+            s = next_s
+            real_traj_len = t
+
+            if done:
+                break
 
-def sample(env, policy, batchsz, process_num):
+        # this is the end of one trajectory
+        sampled_num += real_traj_len
+        sampled_traj_num += 1
+        # t indicates the valid trajectory length
+
+    # this is the end of sampling all batchsz items;
+    # when sampling is over, push all buff data into the queue
+    queue.put([pid, buff])
+    evt.wait()
+
+
+def sample(env, policy, batchsz, process_num, warm_up=False):
     """
     Given batchsz number of tasks, batchsz will be split equally across the processes; when the processes
     return, it merges all the data and returns
@@ -119,7 +182,10 @@ def sample(env, policy, batchsz, process_num):
     processes = []
     for i in range(process_num):
         process_args = (i, queue, evt, env, policy, process_batchsz)
-        processes.append(mp.Process(target=sampler, args=process_args))
+        if warm_up:
+            processes.append(mp.Process(target=warmupsampler, args=process_args))
+        else:
+            processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stopped.
         p.daemon = True
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
     policy.update(epoch)
 
 
+def warm_start(env, policy, batchsz, epoch, process_num):
+    # sample data asynchronously with the rule-based warm-up policy
+    buff = sample(env, policy, batchsz, process_num, warm_up=True)
+    policy.update_memory(buff)
+    policy.update(epoch)
+
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ if __name__ == '__main__':
     evaluator = MultiWozEvaluator()
 
     env = Environment(None, simulator, None, dst_sys, evaluator)
+    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
 
     for i in range(args.epoch):
         update(env, policy_sys, args.batchsz, i, args.process_num)
diff --git a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
index 0d9e5201998e0d9249b4a0cdaca9a0c782c22f90..469d8ff943c45e424045762e0bdd42c4a3d9f2f3 100755
--- a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
+++ b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
@@ -66,7 +66,7 @@ class Beam(object):
 
         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
-        prev_k = best_scores_id / num_words
+        prev_k = best_scores_id // num_words
 
         self.prev_ks.append(prev_k)
         self.next_ys.append(best_scores_id - prev_k * num_words)
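
Note on the warm-up path in dqn.py: rule_action is epsilon-greedy over the warm-up policy itself. With probability epsilon (read from cfg['epsilon_spec']['start']) it emits a uniformly random one-hot dialogue act; otherwise it defers to RuleBasedMultiwozBot. A minimal self-contained sketch of that selection logic, where the rule_predict stub and the da_dim value are hypothetical stand-ins for self.rule_bot.predict and self.vector.da_dim:

import numpy as np
import torch

def warm_up_action(epsilon, da_dim, rule_predict, state):
    # with probability epsilon, explore: pick a random action index
    # and one-hot encode it into an action vector
    if epsilon > np.random.rand():
        a = torch.randint(da_dim, (1,))
        a_vec = torch.zeros(da_dim)
        a_vec[a] = 1.
        return a_vec.numpy()
    # otherwise exploit the rule-based policy
    return rule_predict(state)

# toy demo with a stub rule policy and a 5-dimensional action space
print(warm_up_action(0.1, 5, lambda s: 'rule act', {}))

In the real class the one-hot vector is additionally passed through self.vector.action_devectorize so both branches return a dialogue act in the same format.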
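Note on the training flow in train.py: warm_start runs once before the epoch loop, so the first Q-updates are computed from rule-generated (and occasionally random) transitions rather than from an untrained network's actions. Also worth flagging: update_memory now calls self.memory.reset() before appending, so each call replaces the replay buffer with the freshly sampled batch instead of accumulating across epochs. A toy stand-in, mimicking only the reset/append calls used above rather than the real MemoryReplay API, shows the effect:

class Buffer:
    """Toy buffer, hypothetical: mimics only the reset()/append() calls above."""
    def __init__(self):
        self.batches = []
    def reset(self):
        self.batches = []
    def append(self, batch):
        self.batches.append(batch)

buf = Buffer()
buf.append('epoch-0 samples')
buf.append('epoch-1 samples')
print(buf.batches)   # both batches kept: accumulate-only behaviour

buf.reset()          # what update_memory now does first
buf.append('epoch-2 samples')
print(buf.batches)   # only the newest batch remains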
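Note on the Beam.py change: best_scores_id indexes into a flattened (beam x word) score tensor, so the beam index is the quotient and the word index is the remainder with respect to num_words. Under Python 3 semantics (and PyTorch's true division), / on an integer tensor yields floats that cannot be used as indices; // keeps them integral. A small self-contained illustration, with tensor values made up for the demo:

import torch

num_words = 7
# pretend these are the top-2 ids from a flattened (beam * word) score tensor
best_scores_id = torch.tensor([9, 16])

prev_k = best_scores_id // num_words           # beam each score came from: [1, 2]
next_y = best_scores_id - prev_k * num_words   # word within that beam:     [2, 2]
print(prev_k.tolist(), next_y.tolist())

# with plain '/', prev_k would be a float tensor (~[1.29, 2.29]) and using it
# to index the previous beam states would raise an error or corrupt the pointers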