###############################################################################
# PyDial: Multi-domain Statistical Spoken Dialogue System Software
###############################################################################
#
# Copyright 2015 - 2019
# Cambridge University Engineering Department Dialogue Systems Group
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################
'''
FeudalBBQNPolicy.py - feudal Bayesian deep Q-network (BBQN) policy
===================================================================

Author: Chris Tegho and Pei-Hao (Eddy) Su (Copyright CUED Dialogue Systems Group 2016)

.. seealso:: CUED Imports/Dependencies:

    import :class:`policy.Policy`
    import :class:`policy.BBQNPolicy`
    import :class:`utils.ContextLogger`

.. warning::
    Documentation not done.

************************
'''
import copy
import os
import json
import numpy as np
import pickle
import random
import sys
import utils
from utils.Settings import config as cfg
from utils import ContextLogger, DiaAct, DialogueState
import ontology.FlatOntologyManager as FlatOnt
# from theano_dialogue.util.tool import *
import tensorflow as tf
from policy.DRL.replay_bufferVanilla import ReplayBuffer
from policy.DRL.replay_prioritisedVanilla import ReplayPrioritised
import policy.DRL.utils as drlutils
from policy.DRL import bdqn as bbqn
import policy.Policy
import policy.SummaryAction
import policy.BBQNPolicy
from policy.Policy import TerminalAction, TerminalState
from policy.feudalRL.DIP_parametrisation import DIP_state, padded_state
logger = utils.ContextLogger.getLogger('')
# --- for flattening the belief --- #
domainUtil = FlatOnt.FlatDomainOntology('CamRestaurants')
class FeudalBBQNPolicy(policy.BBQNPolicy.BBQNPolicy):
'''Derived from :class:`BBQNPolicy`
'''
def __init__(self, in_policy_file, out_policy_file, domainString='CamRestaurants', is_training=False,
action_names=None, slot=None):
super(FeudalBBQNPolicy, self).__init__(in_policy_file, out_policy_file, domainString, is_training)
tf.reset_default_graph()
self.domainString = domainString
self.domainUtil = FlatOnt.FlatDomainOntology(self.domainString)
self.in_policy_file = in_policy_file
self.out_policy_file = out_policy_file
self.is_training = is_training
self.accum_belief = []
self.prev_state_check = None
self.episode_ave_max_q = []
        self.capacity *= 4  # capacity was set for episode-based methods; multiply it to adjust to turn-based updates
self.slot = slot
# init session
self.sess = tf.Session()
with tf.device("/cpu:0"):
np.random.seed(self.randomseed)
tf.set_random_seed(self.randomseed)
            # initialise a replay buffer
if self.replay_type == 'vanilla':
self.episodes[self.domainString] = ReplayBuffer(self.capacity, self.minibatch_size, self.randomseed)
elif self.replay_type == 'prioritized':
self.episodes[self.domainString] = ReplayPrioritised(self.capacity, self.minibatch_size,
self.randomseed)
# replay_buffer = ReplayBuffer(self.capacity, self.randomseed)
# self.episodes = []
self.samplecount = 0
self.episodecount = 0
# construct the models
self.state_dim = 89 # current DIP state dim
self.summaryaction = policy.SummaryAction.SummaryAction(domainString)
self.action_names = action_names
self.action_dim = len(self.action_names)
action_bound = len(self.action_names)
self.stats = [0 for _ in range(self.action_dim)]
self.stdVar = []
self.meanVar = []
self.stdMean = []
self.meanMean = []
self.td_error = []
self.td_errorVar = []
self.target_update_freq = 1
if cfg.has_option('bbqnpolicy', 'target_update_freq'):
            self.target_update_freq = cfg.getint('bbqnpolicy', 'target_update_freq')
#feudal params
self.features = 'dip'
self.sd_enc_size = 25
self.si_enc_size = 50
self.dropout_rate = 0.
if cfg.has_option('feudalpolicy', 'features'):
self.features = cfg.get('feudalpolicy', 'features')
if cfg.has_option('feudalpolicy', 'sd_enc_size'):
self.sd_enc_size = cfg.getint('feudalpolicy', 'sd_enc_size')
if cfg.has_option('feudalpolicy', 'si_enc_size'):
self.si_enc_size = cfg.getint('feudalpolicy', 'si_enc_size')
if cfg.has_option('feudalpolicy', 'dropout_rate') and self.is_training:
self.dropout_rate = cfg.getfloat('feudalpolicy', 'dropout_rate')
self.actfreq_ds = False
if cfg.has_option('feudalpolicy', 'actfreq_ds'):
self.actfreq_ds = cfg.getboolean('feudalpolicy', 'actfreq_ds')
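        # when actfreq_ds is set, the belief state is extended with domain-specific action-frequency
        # features, so the corresponding input dimension is enlarged below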
if self.features == 'dip':
if self.actfreq_ds:
if self.domainString == 'CamRestaurants':
self.state_dim += 16
elif self.domainString == 'SFRestaurants':
self.state_dim += 25
elif self.domainString == 'Laptops11':
self.state_dim += 40
self.bbqn = bbqn.DeepQNetwork(self.sess, self.state_dim, self.action_dim, self.learning_rate, self.tau,
action_bound, self.architecture, self.h1_size, self.h2_size,
self.n_samples,
self.minibatch_size, self.sigma_prior, self.n_batches, self.stddev_var_mu,
self.stddev_var_logsigma, self.mean_log_sigma, self.importance_sampling,
self.alpha_divergence, self.alpha, self.sigma_eps)
elif self.features == 'learned' or self.features == 'rnn':
si_state_dim = 72
if self.actfreq_ds:
if self.domainString == 'CamRestaurants':
si_state_dim += 16
elif self.domainString == 'SFRestaurants':
si_state_dim += 25
elif self.domainString == 'Laptops11':
si_state_dim += 40
if self.domainString == 'CamRestaurants':
sd_state_dim = 94
elif self.domainString == 'SFRestaurants':
sd_state_dim = 158
elif self.domainString == 'Laptops11':
sd_state_dim = 13
else:
                logger.error('Domain {} not implemented in feudal-DQN yet'.format(self.domainString))
                # just find out the size of sd_state_dim for the new domain
if self.features == 'rnn':
arch = 'rnn'
self.state_dim = si_state_dim + sd_state_dim
self.bbqn = bbqn.RNNBBQNetwork(self.sess, si_state_dim, sd_state_dim, self.action_dim, self.learning_rate,
self.tau, action_bound, arch, self.h1_size, self.h2_size, self.n_samples,
self.minibatch_size, self.sigma_prior, self.n_batches, self.stddev_var_mu,
self.stddev_var_logsigma, self.mean_log_sigma, self.importance_sampling,
self.alpha_divergence, self.alpha, self.sigma_eps, sd_enc_size=self.sd_enc_size,
                                           si_enc_size=self.si_enc_size, dropout_rate=self.dropout_rate, slot=slot)
else:
arch = 'vanilla'
self.state_dim = si_state_dim + sd_state_dim
self.bbqn = bbqn.NNBBQNetwork(self.sess, si_state_dim, sd_state_dim, self.action_dim, self.learning_rate,
self.tau, action_bound, arch, self.h1_size, self.h2_size, self.n_samples,
self.minibatch_size, self.sigma_prior, self.n_batches, self.stddev_var_mu,
self.stddev_var_logsigma, self.mean_log_sigma, self.importance_sampling,
self.alpha_divergence, self.alpha, self.sigma_eps, sd_enc_size=self.sd_enc_size,
                                           si_enc_size=self.si_enc_size, dropout_rate=self.dropout_rate, slot=slot)
else:
logger.error('features "{}" not implemented'.format(self.features))
# when all models are defined, init all variables
init_op = tf.global_variables_initializer()
self.sess.run(init_op)
self.loadPolicy(self.in_policy_file)
print('loaded replay size: ', self.episodes[self.domainString].size())
self.bbqn.update_target_network()
def record(self, reward, domainInControl=None, weight=None, state=None, action=None, exec_mask=None):
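        '''
        Record a single turn (state, action, reward) in the replay buffer of the given domain.
        The reward is scaled by 1/20; for prioritised replay the current and target Q-values
        are also computed and stored with the sample.
        '''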
if domainInControl is None:
domainInControl = self.domainString
if self.actToBeRecorded is None:
# self.actToBeRecorded = self.lastSystemAction
self.actToBeRecorded = self.summaryAct
if state is None:
state = self.prevbelief
if action is None:
action = self.actToBeRecorded
cState, cAction = state, action
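        # scale the turn-level reward so that the accumulated return stays roughly within [-1, 1]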
reward /= 20.0
cur_cState = np.vstack([np.expand_dims(x, 0) for x in [cState]])
cur_action_q = self.bbqn.predict(cur_cState)
cur_target_q = self.bbqn.predict_target(cur_cState)
if exec_mask is not None:
admissible = np.add(cur_target_q, np.array(exec_mask))
else:
admissible = cur_target_q
Q_s_t_a_t_ = cur_action_q[0][cAction]
gamma_Q_s_tplu1_maxa_ = self.gamma * np.max(admissible)
        if weight is None:
if self.replay_type == 'vanilla':
self.episodes[domainInControl].record(state=cState, \
state_ori=state, action=cAction, reward=reward)
elif self.replay_type == 'prioritized':
                # the current and target Q-values computed above are passed so the prioritised buffer can set the initial priority
if True:
# if self.samplecount >= self.capacity:
self.episodes[domainInControl].record(state=cState, \
state_ori=state, action=cAction, reward=reward, \
Q_s_t_a_t_=Q_s_t_a_t_,
gamma_Q_s_tplu1_maxa_=gamma_Q_s_tplu1_maxa_, uniform=False)
else:
self.episodes[domainInControl].record(state=cState, \
state_ori=state, action=cAction, reward=reward, \
Q_s_t_a_t_=Q_s_t_a_t_,
gamma_Q_s_tplu1_maxa_=gamma_Q_s_tplu1_maxa_, uniform=True)
else:
self.episodes[domainInControl].record(state=cState, state_ori=state, action=cAction, reward=reward,
ma_weight=weight)
self.actToBeRecorded = None
self.samplecount += 1
return
def finalizeRecord(self, reward, domainInControl=None):
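        '''
        Record the terminal transition of the episode with the final (scaled) reward.
        '''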
if domainInControl is None:
domainInControl = self.domainString
if self.episodes[domainInControl] is None:
logger.warning("record attempted to be finalized for domain where nothing has been recorded before")
return
# normalising total return to -1~1
# if reward == 0:
# reward = -20.0
reward /= 20.0
"""
if reward == 20.0:
reward = 1.0
else:
reward = -0.5
"""
# reward = float(reward+10.0)/40.0
terminal_state, terminal_action = self.convertStateAction(TerminalState(), TerminalAction())
if self.replay_type == 'vanilla':
self.episodes[domainInControl].record(state=terminal_state, \
state_ori=TerminalState(), action=terminal_action, reward=reward,
terminal=True)
elif self.replay_type == 'prioritized':
# heuristically assign 0.0 to Q_s_t_a_t_ and Q_s_tplu1_maxa_, doesn't matter as it is not used
if True:
# if self.samplecount >= self.capacity:
self.episodes[domainInControl].record(state=terminal_state, \
state_ori=TerminalState(), action=terminal_action, reward=reward, \
Q_s_t_a_t_=0.0, gamma_Q_s_tplu1_maxa_=0.0, uniform=False,
terminal=True)
else:
self.episodes[domainInControl].record(state=terminal_state, \
state_ori=TerminalState(), action=terminal_action, reward=reward, \
Q_s_t_a_t_=0.0, gamma_Q_s_tplu1_maxa_=0.0, uniform=True,
terminal=True)
def convertStateAction(self, state, action):
        '''
        Convert a belief state and master action into the slot-level DIP belief vector used by this policy;
        terminal states are mapped to a zero vector.
        '''
if isinstance(state, TerminalState):
return [0] * 89, action
else:
if self.features == 'learned' or self.features == 'rnn':
dip_state = padded_state(state.domainStates[state.currentdomain], self.domainString)
else:
dip_state = DIP_state(state.domainStates[state.currentdomain], self.domainString)
action_name = self.actions.action_names[action]
act_slot = 'general'
for slot in dip_state.slots:
if slot in action_name:
act_slot = slot
flat_belief = dip_state.get_beliefStateVec(act_slot)
self.prev_state_check = flat_belief
return flat_belief, action
def nextAction(self, beliefstate):
        '''
        Select the next action.

        :param beliefstate: the DIP-parametrised belief state vector
        :returns: (np.ndarray) Q-values over the summary actions; the feudal master policy converts them into an action
        '''
if self.exploration_type == 'e-greedy':
# epsilon greedy
if self.is_training and utils.Settings.random.rand() < self.epsilon:
action_Q = np.random.rand(len(self.action_names))
else:
action_Q = self.bbqn.predict(np.reshape(beliefstate, (1, len(beliefstate)))) # + (1. / (1. + i + j))
self.episode_ave_max_q.append(np.max(action_Q))
# return the Q vect, the action will be converted in the feudal policy
return action_Q
def train(self):
'''
call this function when the episode ends
'''
if not self.is_training:
logger.info("Not in training mode")
return
else:
logger.info("Update dqn policy parameters.")
self.episodecount += 1
logger.info("Sample Num so far: %s" % (self.samplecount))
logger.info("Episode Num so far: %s" % (self.episodecount))
if self.samplecount >= self.minibatch_size * 10 and self.episodecount % self.training_frequency == 0:
logger.info('start training...')
s_batch, s_ori_batch, a_batch, r_batch, s2_batch, s2_ori_batch, t_batch, idx_batch, _ = \
self.episodes[self.domainString].sample_batch()
s_batch = np.vstack([np.expand_dims(x, 0) for x in s_batch])
s2_batch = np.vstack([np.expand_dims(x, 0) for x in s2_batch])
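            # one-hot encode the sampled actions for the Bayesian Q-network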
a_batch_one_hot = np.eye(self.action_dim, self.action_dim)[a_batch]
action_q = self.bbqn.predict_dip(s2_batch, a_batch_one_hot)
target_q = self.bbqn.predict_target_dip(s2_batch, a_batch_one_hot)
# print 'action Q and target Q:', action_q, target_q
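            # build bootstrapped targets: r for terminal transitions, otherwise
            # r + gamma * max_a' Q_target(s', a') with inadmissible actions masked out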
y_i = []
for k in range(min(self.minibatch_size, self.episodes[self.domainString].size())):
Q_bootstrap_label = 0
if t_batch[k]:
Q_bootstrap_label = r_batch[k]
else:
if self.q_update == 'single':
belief = s2_ori_batch[k]
execMask = [0.0] * len(self.action_names) # TODO: find out how to compute the mask here, or save it when recording the state
execMask[-1] = -sys.maxsize
action_Q = target_q[k]
admissible = np.add(action_Q, np.array(execMask))
Q_bootstrap_label = r_batch[k] + self.gamma * np.max(admissible)
y_i.append(Q_bootstrap_label)
# Update the critic given the targets
reshaped_yi = np.vstack([np.expand_dims(x, 0) for x in y_i])
predicted_q_value, _, currentLoss, logLikelihood, varFC2, meanFC2, td_error, KL_div = self.bbqn.train(s_batch, a_batch_one_hot, reshaped_yi, self.episodecount)
if self.episodecount % self.target_update_freq == 0:
self.bbqn.update_target_network()
if self.episodecount % self.save_step == 0:
self.savePolicyInc() # self.out_policy_file)
# END OF FILE