Commit 68d96eda authored by function2

Merge branch 'master' of github.com:thu-coai/ConvLab-2

parents dcf8c504 e368deeb
@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
 from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
 from convlab2.util.train_util import init_logging_handler
 from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
 from convlab2.util.file_util import cached_path
 import zipfile
 import sys
@@ -32,6 +33,8 @@ class DQN(Policy):
         self.training_iter = cfg['training_iter']
         self.training_batch_iter = cfg['training_batch_iter']
         self.batch_size = cfg['batch_size']
+        self.epsilon = cfg['epsilon_spec']['start']
+        self.rule_bot = RuleBasedMultiwozBot()
         self.gamma = cfg['gamma']
         self.is_train = is_train
         if is_train:
@@ -58,9 +61,10 @@ class DQN(Policy):
         self.loss_fn = nn.MSELoss()

     def update_memory(self, sample):
+        self.memory.reset()
         self.memory.append(sample)

-    def predict(self, state):
+    def predict(self, state, warm_up=False):
         """
         Predict a system action given state.
         Args:
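
A note on the update_memory change: calling self.memory.reset() before append switches the buffer from accumulate-and-sample to replace-per-update, so each training round sees only the most recently sampled batch. A toy sketch of just that semantics (hypothetical FreshReplay class; not the real MemoryReplay interface from convlab2.policy.rlmodule):

    class FreshReplay:
        """Toy buffer with the reset-then-append behavior introduced above."""
        def __init__(self):
            self.items = []

        def reset(self):
            self.items.clear()        # drop transitions gathered under older policies

        def append(self, batch):
            self.items.extend(batch)  # keep only the newest batch

    def update_memory(memory, sample):
        memory.reset()
        memory.append(sample)
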
@@ -68,14 +72,29 @@ class DQN(Policy):
         Returns:
             action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
         """
-        s_vec = torch.Tensor(self.vector.state_vectorize(state))
-        a = self.net.select_action(s_vec.to(device=DEVICE))
-        action = self.vector.action_devectorize(a.numpy())
-        state['system_action'] = action
+        if warm_up:
+            action = self.rule_action(state)
+            state['system_action'] = action
+        else:
+            s_vec = torch.Tensor(self.vector.state_vectorize(state))
+            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
+            action = self.vector.action_devectorize(a.numpy())
+            state['system_action'] = action
         return action

+    def rule_action(self, state):
+        if self.epsilon > np.random.rand():
+            a = torch.randint(self.vector.da_dim, (1, ))
+            # transform the action index into a vector action (one-hot encoding)
+            a_vec = torch.zeros(self.vector.da_dim)
+            a_vec[a] = 1.
+            action = self.vector.action_devectorize(a_vec.numpy())
+        else:
+            # rule-based warm up
+            action = self.rule_bot.predict(state)
+        return action
+
     def init_session(self):
         """
         Restore after one session
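
The new rule_action method is a plain epsilon-greedy warm-up rule: with probability epsilon it emits a random one-hot dialog-act vector, otherwise it delegates to RuleBasedMultiwozBot. A self-contained NumPy sketch of that branch, where da_dim and rule_policy are stand-ins for the vectorizer's action dimension and the rule bot's predict:

    import numpy as np

    def warmup_action(epsilon, da_dim, rule_policy, state, rng=np.random):
        # explore: random one-hot dialog act with probability epsilon
        if epsilon > rng.rand():
            a_vec = np.zeros(da_dim)
            a_vec[rng.randint(da_dim)] = 1.0
            return a_vec
        # exploit: hand-written rule policy
        return rule_policy(state)
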
@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
     queue.put([pid, buff])
     evt.wait()

+def warmupsampler(pid, queue, evt, env, policy, batchsz):
+    """
+    This is a sampler function, called by multiprocessing.Process to sample data from the environment in multiple
+    processes.
+    :param pid: process id
+    :param queue: multiprocessing.Queue, to collect sampled data
+    :param evt: multiprocessing.Event, to keep the process alive
+    :param env: environment instance
+    :param policy: policy network, to generate action from current policy
+    :param batchsz: total sampled items
+    :return:
+    """
+    buff = Memory()
+
+    # we need to sample batchsz of (state, action, next_state, reward, mask)
+    # each trajectory contains `trajectory_len` num of items, so we only need to sample
+    # `batchsz//trajectory_len` num of trajectories in total
+    # the final sampled number may be larger than batchsz.
+
+    sampled_num = 0
+    sampled_traj_num = 0
+    traj_len = 50
+    real_traj_len = 0
+
+    while sampled_num < batchsz:
+        # for each trajectory, we reset the env and get initial state
+        s = env.reset()
+        for t in range(traj_len):
+            # [s_dim] => [a_dim]
+            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
+            a = policy.predict(s, warm_up=True)
+
+            # interact with env
+            next_s, r, done = env.step(a)
+
+            # a flag indicates ending or not
+            mask = 0 if done else 1
+
+            # get reward compared to demonstrations
+            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
+
+            # save to queue
+            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
+
+            # update per step
+            s = next_s
+            real_traj_len = t
+
+            if done:
+                break
+
+        # this is the end of one trajectory
+        sampled_num += real_traj_len
+        sampled_traj_num += 1
+        # t indicates the valid trajectory length
+
+    # this is the end of sampling all batchsz items.
+    # when sampling is over, push all buff data into the queue
+    queue.put([pid, buff])
+    evt.wait()
+
-def sample(env, policy, batchsz, process_num):
+def sample(env, policy, batchsz, process_num, warm_up=False):
     """
     Given batchsz number of tasks, batchsz will be split equally across the processes,
     and when the processes return, it merges all data and returns.
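
Both sampler and warmupsampler end with the same queue-and-event handshake: each worker pushes its buffer and then blocks on evt, so that a daemon process does not exit (taking queued data with it) before the parent has drained the queue. A minimal standalone sketch of the pattern with toy payloads:

    import multiprocessing as mp

    def worker(pid, queue, evt, n):
        queue.put([pid, list(range(n))])  # push this worker's "batch"
        evt.wait()                        # stay alive until the parent has read it

    if __name__ == '__main__':
        queue, evt = mp.Queue(), mp.Event()
        procs = [mp.Process(target=worker, args=(i, queue, evt, 3), daemon=True)
                 for i in range(2)]
        for p in procs:
            p.start()
        batches = [queue.get() for _ in procs]  # drain one result per worker
        evt.set()                               # now let the workers exit
        print(batches)
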
@@ -119,6 +182,9 @@ def sample(env, policy, batchsz, process_num):
     processes = []
     for i in range(process_num):
         process_args = (i, queue, evt, env, policy, process_batchsz)
-        processes.append(mp.Process(target=sampler, args=process_args))
+        if warm_up:
+            processes.append(mp.Process(target=warmupsampler, args=process_args))
+        else:
+            processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stopped.
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
     policy.update(epoch)

+def warm_start(env, policy, batchsz, epoch, process_num):
+    # sample data asynchronously
+    buff = sample(env, policy, batchsz, process_num, warm_up=True)
+    policy.update_memory(buff)
+    policy.update(epoch)
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ if __name__ == '__main__':
     evaluator = MultiWozEvaluator()
     env = Environment(None, simulator, None, dst_sys, evaluator)

+    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
     for i in range(args.epoch):
         update(env, policy_sys, args.batchsz, i, args.process_num)
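
Taken together, the entry point now runs one rule-guided warm-start pass before the usual DQN epochs; since update_memory resets the buffer first, the pretraining step sees only rule-bot transitions. A condensed view of the flow (names from this diff; epoch index 0 is reused for the warm-up update):

    # warm-up: sample with the rule-based bot, refill the replay memory, pretrain
    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)

    # regular training: environment sampling plus Q-network updates each epoch
    for i in range(args.epoch):
        update(env, policy_sys, args.batchsz, i, args.process_num)
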
@@ -66,7 +66,7 @@ class Beam(object):
         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
-        prev_k = best_scores_id / num_words
+        prev_k = best_scores_id // num_words
         self.prev_ks.append(prev_k)
         self.next_ys.append(best_scores_id - prev_k * num_words)
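
The one-character beam change matters on recent PyTorch versions, where / on integer tensors performs true division and returns floats; recovering the beam index from a flattened (beam x word) score index requires floor division instead. A quick sketch of the index arithmetic:

    import torch

    num_beams, num_words = 3, 5
    flat_scores = torch.randn(num_beams * num_words)  # scores flattened beam-major
    best_scores, best_scores_id = flat_scores.topk(num_beams)

    prev_k = best_scores_id // num_words              # which beam each score came from
    next_y = best_scores_id - prev_k * num_words      # word index within the vocabulary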