Commit 68d96eda authored by function2

Merge branch 'master' of github.com:thu-coai/ConvLab-2

parents dcf8c504 e368deeb
@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
from convlab2.util.train_util import init_logging_handler
from convlab2.policy.vector.vector_multiwoz import MultiWozVector
from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
from convlab2.util.file_util import cached_path
import zipfile
import sys
@@ -32,6 +33,8 @@ class DQN(Policy):
        self.training_iter = cfg['training_iter']
        self.training_batch_iter = cfg['training_batch_iter']
        self.batch_size = cfg['batch_size']
        self.epsilon = cfg['epsilon_spec']['start']
        self.rule_bot = RuleBasedMultiwozBot()
        self.gamma = cfg['gamma']
        self.is_train = is_train
        if is_train:
@@ -58,9 +61,10 @@
            self.loss_fn = nn.MSELoss()

    def update_memory(self, sample):
        self.memory.reset()
        self.memory.append(sample)
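For context, the gamma and MSE loss configured above are the ingredients of the standard DQN temporal-difference update. A generic sketch of that loss, not ConvLab-2's actual update code (q_net, target_net, and the tensor shapes are assumptions for illustration):

import torch
import torch.nn as nn

loss_fn = nn.MSELoss()

def td_loss(q_net, target_net, s, a, r, s_next, mask, gamma=0.99):
    # Q(s, a) for the actions actually taken; `a` is int64 of shape [batch, 1]
    q = q_net(s).gather(1, a)
    with torch.no_grad():
        # bootstrapped target r + gamma * max_a' Q_target(s', a'), zeroed at episode end (mask == 0)
        q_next = target_net(s_next).max(dim=1, keepdim=True)[0]
    target = r + gamma * mask * q_next
    return loss_fn(q, target)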
    def predict(self, state):
    def predict(self, state, warm_up=False):
        """
        Predict a system action given the state.
        Args:
@@ -68,14 +72,29 @@
        Returns:
            action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
        """
        if warm_up:
            action = self.rule_action(state)
            state['system_action'] = action
        else:
            s_vec = torch.Tensor(self.vector.state_vectorize(state))
            a = self.net.select_action(s_vec.to(device=DEVICE))
            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
            action = self.vector.action_devectorize(a.numpy())
            state['system_action'] = action
        return action
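A minimal usage sketch of the two predict modes (here `policy` is a DQN instance and `state` the MultiWOZ dialog-state dict maintained by the tracker):

action = policy.predict(state)                      # action selected by the Q-network
warm_action = policy.predict(state, warm_up=True)   # random or rule-based warm-up action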
    def rule_action(self, state):
        if self.epsilon > np.random.rand():
            a = torch.randint(self.vector.da_dim, (1, ))
            # transforms action index to a vector action (one-hot encoding)
            a_vec = torch.zeros(self.vector.da_dim)
            a_vec[a] = 1.
            action = self.vector.action_devectorize(a_vec.numpy())
        else:
            # rule-based warm up
            action = self.rule_bot.predict(state)
        return action
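Only epsilon_spec['start'] is read above; DQN implementations usually anneal epsilon toward an end value as training progresses. A minimal sketch of a linear schedule, with hypothetical start/end/steps values rather than ConvLab-2's actual config keys:

def linear_epsilon(step, start=0.1, end=0.0, steps=10000):
    # linearly anneal the exploration rate from `start` down to `end` over `steps` updates
    frac = min(step / steps, 1.0)
    return start + frac * (end - start)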
    def init_session(self):
        """
        Restore after one session
@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
    queue.put([pid, buff])
    evt.wait()
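Each sampler worker ends by pushing its buffer through the queue and then blocking on evt.wait(), so its data stays alive until the parent has merged everything. A sketch of the parent-side collection this handshake implies (illustrative names; the real merging code sits behind the fold):

import torch.multiprocessing as mp

def collect(queue, evt, process_num):
    # drain one [pid, buff] pair per worker, then release the workers all at once
    buffers = [queue.get() for _ in range(process_num)]
    evt.set()
    return buffers

queue, evt = mp.Queue(), mp.Event()
# spawn `process_num` workers targeting sampler/warmupsampler here, then:
# buffers = collect(queue, evt, process_num)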
def warmupsampler(pid, queue, evt, env, policy, batchsz):
    """
    This is a sampler function; it will be called by multiprocessing.Process to sample data from the environment
    in multiple processes.
    :param pid: process id
    :param queue: multiprocessing.Queue, to collect sampled data
    :param evt: multiprocessing.Event, to keep the process alive
    :param env: environment instance
    :param policy: policy network, to generate action from the current policy
    :param batchsz: total sampled items
    :return:
    """
    buff = Memory()

    # we need to sample batchsz items of (state, action, next_state, reward, mask);
    # each trajectory contains at most `traj_len` items, so we only need to sample
    # about `batchsz//traj_len` trajectories in total.
    # the final sampled number may be slightly larger than batchsz.
    sampled_num = 0
    sampled_traj_num = 0
    traj_len = 50
    real_traj_len = 0
    while sampled_num < batchsz:
        # for each trajectory, we reset the env and get initial state
        s = env.reset()
        for t in range(traj_len):
            # [s_dim] => [a_dim]
            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
            a = policy.predict(s, warm_up=True)
            # interact with env
            next_s, r, done = env.step(a)
            # a flag indicates ending or not
            mask = 0 if done else 1
            # get reward compared to demonstrations
            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
            # save to queue
            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
            # update per step
            s = next_s
            real_traj_len = t
            if done:
                break
        # this is the end of one trajectory
        sampled_num += real_traj_len
        sampled_traj_num += 1
        # t indicates the valid trajectory length

    # this is the end of sampling all batchsz items;
    # when sampling is over, push all buff data into the queue
    queue.put([pid, buff])
    evt.wait()

def sample(env, policy, batchsz, process_num):
def sample(env, policy, batchsz, process_num, warm_up=False):
"""
Given batchsz number of task, the batchsz will be splited equally to each processes
and when processes return, it merge all data and return
@@ -119,6 +182,9 @@ def sample(env, policy, batchsz, process_num):
    processes = []
    for i in range(process_num):
        process_args = (i, queue, evt, env, policy, process_batchsz)
        if warm_up:
            processes.append(mp.Process(target=warmupsampler, args=process_args))
        else:
            processes.append(mp.Process(target=sampler, args=process_args))

    for p in processes:
        # set the process as daemon, so it will be killed once the main process is stopped.
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
    policy.update(epoch)
def warm_start(env, policy, batchsz, epoch, process_num):
    # sample data asynchronously
    buff = sample(env, policy, batchsz, process_num, warm_up=True)
    policy.update_memory(buff)
    policy.update(epoch)
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ if __name__ == '__main__':
    evaluator = MultiWozEvaluator()
    env = Environment(None, simulator, None, dst_sys, evaluator)
    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
    for i in range(args.epoch):
        update(env, policy_sys, args.batchsz, i, args.process_num)
@@ -66,7 +66,7 @@ class Beam(object):
        # bestScoresId is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from
        prev_k = best_scores_id / num_words
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
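The one-character fix above matters because `/` on tensors performs true division in Python 3 and recent PyTorch, returning floats that can no longer serve as beam indices; `//` restores integer floor division. A quick illustration:

import torch

best_scores_id, num_words = torch.tensor([7, 12]), 5
print(best_scores_id / num_words)    # tensor([1.4000, 2.4000]) -- floats, unusable as indices
print(best_scores_id // num_words)   # tensor([1, 2]) -- integer beam indices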