Commit 3fd04ce2 authored by Bruno Eidi Nishimoto, committed by GitHub

DQN (#113)

* implemented a script to extract statistics for all dialogue acts in the data

* changed the script so its actions are compatible with the sys_da_voc.txt actions

* multiwoz vector now supports composite actions

* implemented MemoryReplay and EpsilonGreedyPolicy

* implemented a basic version of DQN

* included some comments
parent 24f3c2da
from convlab2.policy.dqn.dqn import DQN
{
"batch_size": 16,
"gamma": 0.99,
"lr": 0.001,
"save_dir": "save",
"log_dir": "log",
"save_per_epoch": 5,
"training_iter": 10,
"training_batch_iter": 3,
"h_dim": 100,
"hv_dim": 50,
"memory_size": 5000,
"epsilon_spec": {
"start": 0.1,
"end": 0.0,
"end_epoch": 200
},
"load": "save/best",
"vocab_size": 500
}
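The hyperparameters above are consumed by the DQN class in dqn.py below. As a rough illustration (a sketch, not part of the commit) of how the training-loop settings combine: each call to DQN.update() draws training_iter batches of batch_size transitions from the replay memory and takes training_batch_iter gradient steps on each batch.
# Illustration only, using the values from config.json above.
batch_size, training_iter, training_batch_iter = 16, 10, 3
gradient_steps_per_epoch = training_iter * training_batch_iter  # 30 optimizer steps per DQN.update() call
transitions_drawn_per_epoch = training_iter * batch_size        # 160 transitions sampled (each batch reused 3 times)
print(gradient_steps_per_epoch, transitions_drawn_per_epoch)    # 30 160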
# -*- coding: utf-8 -*-
import torch
from torch import optim
from torch import nn
import numpy as np
import logging
import os
import json
import copy
from convlab2.policy.policy import Policy
from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
from convlab2.util.train_util import init_logging_handler
from convlab2.policy.vector.vector_multiwoz import MultiWozVector
from convlab2.util.file_util import cached_path
import zipfile
import sys
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(root_dir)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DQN(Policy):
def __init__(self, is_train=False, dataset='Multiwoz'):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
cfg = json.load(f)
self.save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['save_dir'])
self.save_per_epoch = cfg['save_per_epoch']
self.training_iter = cfg['training_iter']
self.training_batch_iter = cfg['training_batch_iter']
self.batch_size = cfg['batch_size']
self.gamma = cfg['gamma']
self.is_train = is_train
if is_train:
init_logging_handler(os.path.join(os.path.dirname(os.path.abspath(__file__)), cfg['log_dir']))
# construct multiwoz vector
if dataset == 'Multiwoz':
voc_file = os.path.join(root_dir, 'data/multiwoz/sys_da_voc.txt')
voc_opp_file = os.path.join(root_dir, 'data/multiwoz/usr_da_voc.txt')
self.vector = MultiWozVector(voc_file, voc_opp_file, composite_actions=True, vocab_size=cfg['vocab_size'])
# replay memory
self.memory = MemoryReplay(cfg['memory_size'])
self.net = EpsilonGreedyPolicy(self.vector.state_dim, cfg['hv_dim'], self.vector.da_dim, cfg['epsilon_spec']).to(device=DEVICE)
self.target_net = copy.deepcopy(self.net)
self.online_net = self.target_net
self.eval_net = self.target_net
if is_train:
self.net_optim = optim.Adam(self.net.parameters(), lr=cfg['lr'])
self.loss_fn = nn.MSELoss()
def update_memory(self, sample):
self.memory.append(sample)
def predict(self, state):
"""
Predict a system action given the dialog state.
Args:
state (dict): Dialog state. Please refer to util/state.py
Returns:
action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
"""
s_vec = torch.Tensor(self.vector.state_vectorize(state))
a = self.net.select_action(s_vec.to(device=DEVICE))
action = self.vector.action_devectorize(a.numpy())
state['system_action'] = action
return action
def init_session(self):
"""
Restore after one session
"""
self.memory.reset()
def calc_q_loss(self, batch):
'''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
s = torch.from_numpy(np.stack(batch.state)).to(device=DEVICE)
a = torch.from_numpy(np.stack(batch.action)).to(device=DEVICE)
r = torch.from_numpy(np.stack(batch.reward)).to(device=DEVICE)
next_s = torch.from_numpy(np.stack(batch.next_state)).to(device=DEVICE)
mask = torch.Tensor(np.stack(batch.mask)).to(device=DEVICE)
q_preds = self.net(s)
with torch.no_grad():
# Use online_net to select actions in next state
online_next_q_preds = self.online_net(next_s)
# Use eval_net to calculate next_q_preds for actions chosen by online_net
next_q_preds = self.eval_net(next_s)
act_q_preds = q_preds.gather(-1, a.argmax(-1).long().unsqueeze(-1)).squeeze(-1)
online_actions = online_next_q_preds.argmax(dim=-1, keepdim=True)
max_next_q_preds = next_q_preds.gather(-1, online_actions).squeeze(-1)
max_q_targets = r + self.gamma * mask * max_next_q_preds
q_loss = self.loss_fn(act_q_preds, max_q_targets)
return q_loss
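# Note on the target above: calc_q_loss is written in the Double DQN form
#     y = r + gamma * mask * Q_eval(s', argmax_a Q_online(s', a)),
# where one network selects the next action and another scores it. With the wiring
# in __init__ (online_net and eval_net both alias target_net), selection and
# evaluation use the same target network, so the target reduces to the standard
# DQN target y = r + gamma * mask * max_a Q_target(s', a).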
def update(self, epoch):
total_loss = 0.
for i in range(self.training_iter):
round_loss = 0.
# 1. sample a batch from memory
batch = self.memory.get_batch(batch_size=self.batch_size)
for _ in range(self.training_batch_iter):
# 2. calculate the Q loss
loss = self.calc_q_loss(batch)
# 3. take an optimization step
self.net_optim.zero_grad()
loss.backward()
self.net_optim.step()
round_loss += loss.item()
logging.debug('<<dialog policy dqn>> epoch {}, iteration {}, loss {}'.format(epoch, i, round_loss / self.training_batch_iter))
total_loss += round_loss
total_loss /= (self.training_batch_iter * self.training_iter)
logging.debug('<<dialog policy dqn>> epoch {}, total_loss {}'.format(epoch, total_loss))
# update the epsilon value
self.net.update_epsilon(epoch)
# update the target network
self.target_net.load_state_dict(self.net.state_dict())
if (epoch+1) % self.save_per_epoch == 0:
self.save(self.save_dir, epoch)
def save(self, directory, epoch):
if not os.path.exists(directory):
os.makedirs(directory)
torch.save(self.net.state_dict(), directory + '/' + str(epoch) + '_dqn.pol.mdl')
logging.info('<<dialog policy>> epoch {}: saved network to mdl'.format(epoch))
def load(self, filename):
dqn_mdl_candidates = [
filename + '.dqn.mdl',
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.dqn.mdl'),
]
for dqn_mdl in dqn_mdl_candidates:
if os.path.exists(dqn_mdl):
self.net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE))
self.target_net.load_state_dict(torch.load(dqn_mdl, map_location=DEVICE))
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(dqn_mdl))
break
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 14 16:14:07 2019
@author: truthless
"""
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import numpy as np
import torch
import logging
from torch import multiprocessing as mp
from convlab2.dialog_agent.agent import PipelineAgent
from convlab2.dialog_agent.env import Environment
from convlab2.nlu.svm.multiwoz import SVMNLU
from convlab2.dst.rule.multiwoz import RuleDST
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.policy.dqn import DQN
from convlab2.policy.rlmodule import Memory, Transition
from convlab2.nlg.template.multiwoz import TemplateNLG
from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
from argparse import ArgumentParser
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
mp = mp.get_context('spawn')
except RuntimeError:
pass
def sampler(pid, queue, evt, env, policy, batchsz):
"""
This is a sampler function; it will be called by multiprocessing.Process to sample data from the
environment in multiple processes.
:param pid: process id
:param queue: multiprocessing.Queue, to collect sampled data
:param evt: multiprocessing.Event, to keep the process alive
:param env: environment instance
:param policy: policy network, to generate action from current policy
:param batchsz: total sampled items
:return:
"""
buff = Memory()
# we need to sample batchsz items of (state, action, reward, next_state, mask)
# each trajectory contains at most `traj_len` items, so we only need to sample
# about `batchsz // traj_len` trajectories in total
# the final sampled number may be slightly larger than batchsz.
sampled_num = 0
sampled_traj_num = 0
traj_len = 50
real_traj_len = 0
while sampled_num < batchsz:
# for each trajectory, we reset the env and get initial state
s = env.reset()
for t in range(traj_len):
# [s_dim] => [a_dim]
s_vec = torch.Tensor(policy.vector.state_vectorize(s))
a = policy.predict(s)
# interact with env
next_s, r, done = env.step(a)
# a flag indicates ending or not
mask = 0 if done else 1
# vectorize the next state
next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
# save to queue
buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
# update per step
s = next_s
real_traj_len = t
if done:
break
# this is end of one trajectory
sampled_num += real_traj_len
sampled_traj_num += 1
# t indicates the valid trajectory length
# this is end of sampling all batchsz of items.
# when sampling is over, push all buff data into queue
queue.put([pid, buff])
evt.wait()
def sample(env, policy, batchsz, process_num):
"""
Given a total of batchsz sampled items, batchsz is split equally across the processes;
when the processes return, all of their data is merged and returned.
:param env:
:param policy:
:param batchsz:
:param process_num:
:return: batch
"""
# batchsz is split across the processes;
# the final number of sampled items may be larger than the batchsz parameter
process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
# buffer to save all data
queue = mp.Queue()
# start one sampler process per pid in range(process_num)
# when a tensor is saved in a Queue, the producing process must stay alive until Queue.get();
# please refer to: https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847
# there are still problems with CUDA tensors on a multiprocessing queue;
# please refer to: https://discuss.pytorch.org/t/cuda-tensors-on-multiprocessing-queue/28626
# so tensors are converted to numpy arrays before being put into the queue.
evt = mp.Event()
processes = []
for i in range(process_num):
process_args = (i, queue, evt, env, policy, process_batchsz)
processes.append(mp.Process(target=sampler, args=process_args))
for p in processes:
# set the process as a daemon so it is killed once the main process stops.
p.daemon = True
p.start()
# get the first Memory object, then merge the others into it using its append method.
pid0, buff0 = queue.get()
for _ in range(1, process_num):
pid, buff_ = queue.get()
buff0.append(buff_) # merge current Memory into buff0
evt.set()
# now buff saves all the sampled data
buff = buff0
return buff
def update(env, policy, batchsz, epoch, process_num):
# sample data asynchronously
buff = sample(env, policy, batchsz, process_num)
policy.update_memory(buff)
policy.update(epoch)
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--load_path", type=str, default="", help="path of model to load")
parser.add_argument("--batchsz", type=int, default=100, help="batch size of trajactory sampling")
parser.add_argument("--epoch", type=int, default=200, help="number of epochs to train")
parser.add_argument("--process_num", type=int, default=3, help="number of processes of trajactory sampling")
args = parser.parse_args()
# simple rule DST
dst_sys = RuleDST()
policy_sys = DQN(True)
policy_sys.load(args.load_path)
# the user side does not use a DST
dst_usr = None
# rule policy
policy_usr = RulePolicy(character='usr')
# assemble
simulator = PipelineAgent(None, None, policy_usr, None, 'user')
evaluator = MultiWozEvaluator()
env = Environment(None, simulator, None, dst_sys, evaluator)
for i in range(args.epoch):
update(env, policy_sys, args.batchsz, i, args.process_num)
@@ -56,6 +56,63 @@ class DiscretePolicy(nn.Module):
return log_prob
class EpsilonGreedyPolicy(nn.Module):
def __init__(self, s_dim, h_dim, a_dim, epsilon_spec={'start': 0.1, 'end': 0.0, 'end_epoch': 200}):
super(EpsilonGreedyPolicy, self).__init__()
self.net = nn.Sequential(nn.Linear(s_dim, h_dim),
nn.ReLU(),
nn.Linear(h_dim, h_dim),
nn.ReLU(),
nn.Linear(h_dim, a_dim))
self.epsilon = epsilon_spec['start']
self.start = epsilon_spec['start']
self.end = epsilon_spec['end']
self.end_epoch = epsilon_spec['end_epoch']
self.a_dim = a_dim
def forward(self, s):
# [b, s_dim] => [b, a_dim]
a_weights = self.net(s)
return a_weights
def select_action(self, s, is_train=True):
"""
:param s: [s_dim]
:return: [a_dim] one-hot action vector
"""
# forward pass to get the action values (Q-values)
# [s_dim] => [a_dim]
if is_train:
if self.epsilon > np.random.rand():
# select a random action
a = torch.randint(self.a_dim, (1, ))
else:
a = self._greedy_action(s)
else:
a = self._greedy_action(s)
# transform the action index into a one-hot action vector
a_vec = torch.zeros(self.a_dim)
a_vec[a] = 1.
return a_vec
def update_epsilon(self, epoch):
# Linear decay
a = -float(self.start - self.end) / self.end_epoch
b = float(self.start)
self.epsilon = max(self.end, a * float(epoch) + b)
def _greedy_action(self, s):
"""
Select a greedy action
"""
a_weights = self.forward(s)
return a_weights.argmax(0, True)
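# A standalone note (not part of the module) on the linear decay implemented by
# update_epsilon above, with the default epsilon_spec (start=0.1, end=0.0, end_epoch=200):
#     epsilon(epoch) = max(end, start - (start - end) * epoch / end_epoch)
# which gives 0.1 at epoch 0, 0.05 at epoch 100, and 0.0 from epoch 200 onwards.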
class MultiDiscretePolicy(nn.Module):
def __init__(self, s_dim, h_dim, a_dim):
@@ -224,3 +281,42 @@ class Memory(object):
def __len__(self):
return len(self.memory)
class MemoryReplay(object):
"""
Unlike the Memory class, MemoryReplay has a limited size.
It is mainly used for off-policy algorithms.
"""
def __init__(self, max_size):
self.memory = []
self.index = 0
self.max_size = max_size
def push(self, *args):
"""Saves a transition."""
if len(self.memory) < self.max_size:
self.memory.append(None)
self.memory[self.index] = Transition(*args)
self.index = (self.index + 1) % self.max_size
def get_batch(self, batch_size=None):
if batch_size is None:
return Transition(*zip(*self.memory))
else:
random_batch = random.sample(self.memory, batch_size)
return Transition(*zip(*random_batch))
def append(self, new_memory):
for transition in new_memory.memory:
if len(self.memory) < self.max_size:
self.memory.append(None)
self.memory[self.index] = transition
self.index = (self.index + 1) % self.max_size
def reset(self):
self.memory = []
self.index = 0
def __len__(self):
return len(self.memory)
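A minimal usage sketch of MemoryReplay (an illustration, assuming convlab2 is importable and that the Transition fields follow the (state, action, reward, next_state, mask) order used by the sampler in train.py; the pushed values are placeholders):
import numpy as np
from convlab2.policy.rlmodule import MemoryReplay

memory = MemoryReplay(max_size=3)
for step in range(5):
    # pushing past max_size wraps around and overwrites the oldest transitions
    memory.push(np.zeros(4), np.zeros(2), -1.0, np.zeros(4), 1)
print(len(memory))                      # 3
batch = memory.get_batch(batch_size=2)  # Transition namedtuple whose fields are tuples of sampled entries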
@@ -20,16 +20,22 @@ mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'na
'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'},
'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}}
DEFAULT_INTENT_FILEPATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
'data/multiwoz/trackable_intent.json'
)
class MultiWozVector(Vector):
def __init__(self, voc_file, voc_opp_file, character='sys',
intent_file=os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
'data/multiwoz/trackable_intent.json')):
intent_file=DEFAULT_INTENT_FILEPATH,
composite_actions=False,
vocab_size=500):
self.belief_domains = ['Attraction', 'Restaurant', 'Train', 'Hotel', 'Taxi', 'Hospital', 'Police']
self.db_domains = ['Attraction', 'Restaurant', 'Train', 'Hotel']
self.composite_actions = composite_actions
self.vocab_size = vocab_size
with open(intent_file) as f:
intents = json.load(f)
@@ -41,10 +47,31 @@ class MultiWozVector(Vector):
self.da_voc = f.read().splitlines()
with open(voc_opp_file) as f:
self.da_voc_opp = f.read().splitlines()
if self.composite_actions:
self.load_composite_actions()
self.character = character
self.generate_dict()
self.cur_domain = None
def load_composite_actions(self):
"""
Load the composite actions into self.da_voc.
"""
composite_actions_filepath = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
'data/multiwoz/da_slot_cnt.json')
with open(composite_actions_filepath, 'r') as f:
composite_actions_stats = json.load(f)
for action in composite_actions_stats:
if len(action.split(';')) > 1:
# append only composite actions as single actions are already in self.da_voc
self.da_voc.append(action)
if len(self.da_voc) == self.vocab_size:
break
def generate_dict(self):
"""
init the dict for mapping state/action into vector
@@ -195,6 +222,11 @@ class MultiWozVector(Vector):
Dialog act
"""
act_array = []
if self.composite_actions:
act_idx = np.argmax(action_vec)
act_array = self.vec2act[act_idx].split(';')
else:
for i, idx in enumerate(action_vec):
if idx == 1:
act_array.append(self.vec2act[i])
@@ -213,6 +245,14 @@ class MultiWozVector(Vector):
action = delexicalize_da(action, self.requestable)
action = flat_da(action)
act_vec = np.zeros(self.da_dim)
if self.composite_actions:
composite_action = ';'.join(action)
for act in self.act2vec:
if set(action) == set(act.split(';')):
act_vec[self.act2vec[act]] = 1.
break
else:
for da in action:
if da in self.act2vec:
act_vec[self.act2vec[da]] = 1.
This diff is collapsed.
import json
import zipfile
from collections import Counter, OrderedDict
def read_zipped_json(filepath, filename):
archive = zipfile.ZipFile(filepath, 'r')
return json.load(archive.open(filename))
def extract_act(data, counter):
for dialogs in list(data.values()):
for turn, meta in enumerate(dialogs['log']):
if turn % 2 == 0:
# usr turn
continue
action = ''
if meta['dialog_act']:
for act, slots in meta['dialog_act'].items():
for slot in slots:
action += f'{act}-'
if slot[1] == '?' or slot[0] == 'none':
action += f'{slot[0]}-{slot[1]};'
else:
action += f'{slot[0]}-1;'
counter.update([action[:-1]])
if __name__ == '__main__':
counter = Counter()
for s in ['train', 'val', 'test']:
data = read_zipped_json(s + '.json.zip', s + '.json')
extract_act(data, counter)
with open('da_slot_cnt.json', 'w') as f:
json.dump(OrderedDict(counter.most_common()), f, indent=2)
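The resulting da_slot_cnt.json maps each (possibly composite) dialogue-act string, with slot values replaced by a placeholder, to its frequency in the corpus, ordered from most to least common; load_composite_actions in vector_multiwoz.py reads this file to extend the action vocabulary. A quick check of the output (the key and count in the comment are hypothetical, shown only to illustrate the format):
import json

with open('da_slot_cnt.json') as f:
    da_slot_cnt = json.load(f)
for action, count in list(da_slot_cnt.items())[:3]:
    print(action, count)  # e.g. Hotel-Inform-Area-1;Hotel-Inform-Price-1  1234  (hypothetical)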