diff --git a/convlab2/policy/dqn/dqn.py b/convlab2/policy/dqn/dqn.py
index 39c04ac13c6e3e1cee03b591c058b86e15e0de47..3a029290ca9cb2591c2ee5ae44e832568c5eb623 100644
--- a/convlab2/policy/dqn/dqn.py
+++ b/convlab2/policy/dqn/dqn.py
@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
 from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
 from convlab2.util.train_util import init_logging_handler
 from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
 from convlab2.util.file_util import cached_path
 import zipfile
 import sys
@@ -32,6 +33,8 @@ class DQN(Policy):
         self.training_iter = cfg['training_iter']
         self.training_batch_iter = cfg['training_batch_iter']
         self.batch_size = cfg['batch_size']
+        self.epsilon = cfg['epsilon_spec']['start']
+        self.rule_bot = RuleBasedMultiwozBot()
         self.gamma = cfg['gamma']
         self.is_train = is_train
         if is_train:
@@ -58,9 +61,10 @@ class DQN(Policy):
         self.loss_fn = nn.MSELoss()
 
     def update_memory(self, sample):
+        self.memory.reset()
         self.memory.append(sample)
 
-    def predict(self, state):
+    def predict(self, state, warm_up=False):
         """
         Predict a system action given the state.
         Args:
@@ -68,12 +72,27 @@ class DQN(Policy):
         Returns:
             action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
         """
-        s_vec = torch.Tensor(self.vector.state_vectorize(state))
-        a = self.net.select_action(s_vec.to(device=DEVICE))
-
-        action = self.vector.action_devectorize(a.numpy())
-
-        state['system_action'] = action
+        if warm_up:
+            action = self.rule_action(state)
+            state['system_action'] = action
+        else:
+            s_vec = torch.Tensor(self.vector.state_vectorize(state))
+            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
+            action = self.vector.action_devectorize(a.numpy())
+            state['system_action'] = action
+        return action
+
+    def rule_action(self, state):
+        if self.epsilon > np.random.rand():
+            a = torch.randint(self.vector.da_dim, (1, ))
+            # transform the action index into a one-hot action vector
+            a_vec = torch.zeros(self.vector.da_dim)
+            a_vec[a] = 1.
+            action = self.vector.action_devectorize(a_vec.numpy())
+        else:
+            # rule-based warm-up action
+            action = self.rule_bot.predict(state)
+        return action
 
     def init_session(self):
diff --git a/convlab2/policy/dqn/train.py b/convlab2/policy/dqn/train.py
index 8ebcf280a73e20bc289a8860713fea78c878b312..2c6412f8e5a3f5d65e38a1fb6c53e5058e94ab83 100755
--- a/convlab2/policy/dqn/train.py
+++ b/convlab2/policy/dqn/train.py
@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
     queue.put([pid, buff])
     evt.wait()
 
+def warmupsampler(pid, queue, evt, env, policy, batchsz):
+    """
+    This is a sampler function; it is called by multiprocessing.Process to sample warm-up data from the
+    environment in multiple processes.
+    :param pid: process id
+    :param queue: multiprocessing.Queue, to collect sampled data
+    :param evt: multiprocessing.Event, to keep the process alive
+    :param env: environment instance
+    :param policy: policy network, to generate action from current policy
+    :param batchsz: total sampled items
+    :return:
+    """
+    buff = Memory()
+
+    # we need to sample batchsz items of (state, action, next_state, reward, mask)
+    # each trajectory contains at most `traj_len` items, so we only need to sample
+    # about `batchsz // traj_len` trajectories in total
+    # the final sampled number may be larger than batchsz.
+
+    sampled_num = 0
+    sampled_traj_num = 0
+    traj_len = 50
+    real_traj_len = 0
+
+    while sampled_num < batchsz:
+        # for each trajectory, we reset the env and get the initial state
+        s = env.reset()
+
+        for t in range(traj_len):
+
+            # [s_dim] => [a_dim]
+            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
+            a = policy.predict(s, warm_up=True)
+
+            # interact with env
+            next_s, r, done = env.step(a)
+
+            # a flag indicating whether the dialogue has ended
+            mask = 0 if done else 1
+
+            # vectorize the next state
+            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
+
+            # save to the local buffer
+            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
+
+            # update per step
+            s = next_s
+            real_traj_len = t
+
+            if done:
+                break
 
-def sample(env, policy, batchsz, process_num):
+        # this is the end of one trajectory
+        sampled_num += real_traj_len
+        sampled_traj_num += 1
+        # t indicates the valid trajectory length
+
+    # this is the end of sampling all batchsz items;
+    # when sampling is over, push all buff data into the queue
+    queue.put([pid, buff])
+    evt.wait()
+
+
+def sample(env, policy, batchsz, process_num, warm_up=False):
     """
     Given batchsz number of tasks, batchsz will be split equally across the processes; when the processes
     return, it merges all the data and returns
@@ -119,7 +182,10 @@ def sample(env, policy, batchsz, process_num):
     processes = []
     for i in range(process_num):
         process_args = (i, queue, evt, env, policy, process_batchsz)
-        processes.append(mp.Process(target=sampler, args=process_args))
+        if warm_up:
+            processes.append(mp.Process(target=warmupsampler, args=process_args))
+        else:
+            processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stopped.
         p.daemon = True
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
     policy.update(epoch)
 
 
+def warm_start(env, policy, batchsz, epoch, process_num):
+    # sample data asynchronously with the rule-based warm-up policy
+    buff = sample(env, policy, batchsz, process_num, warm_up=True)
+    policy.update_memory(buff)
+    policy.update(epoch)
+
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ if __name__ == '__main__':
     evaluator = MultiWozEvaluator()
 
     env = Environment(None, simulator, None, dst_sys, evaluator)
+    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
 
     for i in range(args.epoch):
         update(env, policy_sys, args.batchsz, i, args.process_num)
diff --git a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
index 0d9e5201998e0d9249b4a0cdaca9a0c782c22f90..469d8ff943c45e424045762e0bdd42c4a3d9f2f3 100755
--- a/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
+++ b/convlab2/policy/hdsa/multiwoz/transformer/Beam.py
@@ -66,7 +66,7 @@ class Beam(object):
 
         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
-        prev_k = best_scores_id / num_words
+        prev_k = best_scores_id // num_words
 
         self.prev_ks.append(prev_k)
         self.next_ys.append(best_scores_id - prev_k * num_words)
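
Note on the warm-up path in dqn.py: rule_action is epsilon-greedy over the warm-up policy itself. With probability epsilon (read from cfg['epsilon_spec']['start']) it emits a uniformly random one-hot dialogue act; otherwise it defers to RuleBasedMultiwozBot. A minimal self-contained sketch of that selection logic, where the rule_predict stub and the da_dim value are hypothetical stand-ins for self.rule_bot.predict and self.vector.da_dim:

import numpy as np
import torch

def warm_up_action(epsilon, da_dim, rule_predict, state):
    # with probability epsilon, explore: pick a random action index
    # and one-hot encode it into an action vector
    if epsilon > np.random.rand():
        a = torch.randint(da_dim, (1,))
        a_vec = torch.zeros(da_dim)
        a_vec[a] = 1.
        return a_vec.numpy()
    # otherwise exploit the rule-based policy
    return rule_predict(state)

# toy demo with a stub rule policy and a 5-dimensional action space
print(warm_up_action(0.1, 5, lambda s: 'rule act', {}))

In the real class the one-hot vector is additionally passed through self.vector.action_devectorize so both branches return a dialogue act in the same format.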
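Note on the training flow in train.py: warm_start runs once before the epoch loop, so the first Q-updates are computed from rule-generated (and occasionally random) transitions rather than from an untrained network's actions. Also worth flagging: update_memory now calls self.memory.reset() before appending, so each call replaces the replay buffer with the freshly sampled batch instead of accumulating across epochs. A toy stand-in, mimicking only the reset/append calls used above rather than the real MemoryReplay API, shows the effect:

class Buffer:
    """Toy buffer, hypothetical: mimics only the reset()/append() calls above."""
    def __init__(self):
        self.batches = []
    def reset(self):
        self.batches = []
    def append(self, batch):
        self.batches.append(batch)

buf = Buffer()
buf.append('epoch-0 samples')
buf.append('epoch-1 samples')
print(buf.batches)   # both batches kept: accumulate-only behaviour

buf.reset()          # what update_memory now does first
buf.append('epoch-2 samples')
print(buf.batches)   # only the newest batch remains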
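Note on the Beam.py change: best_scores_id indexes into a flattened (beam x word) score tensor, so the beam index is the quotient and the word index is the remainder with respect to num_words. Under Python 3 semantics (and PyTorch's true division), / on an integer tensor yields floats that cannot be used as indices; // keeps them integral. A small self-contained illustration, with tensor values made up for the demo:

import torch

num_words = 7
# pretend these are the top-2 ids from a flattened (beam * word) score tensor
best_scores_id = torch.tensor([9, 16])

prev_k = best_scores_id // num_words           # beam each score came from: [1, 2]
next_y = best_scores_id - prev_k * num_words   # word within that beam:     [2, 2]
print(prev_k.tolist(), next_y.tolist())

# with plain '/', prev_k would be a float tensor (~[1.29, 2.29]) and using it
# to index the previous beam states would raise an error or corrupt the pointers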