FeudalBBQNPolicy.py
    ###############################################################################
    # PyDial: Multi-domain Statistical Spoken Dialogue System Software
    ###############################################################################
    #
    # Copyright 2015 - 2019
    # Cambridge University Engineering Department Dialogue Systems Group
    #
    # 
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #
    ###############################################################################
    
    '''
    FeudalBBQNPolicy.py - feudal BBQN (Bayes-by-Backprop Q-network) policy
    ========================================================================
    
    Author: Chris Tegho and Pei-Hao (Eddy) Su  (Copyright CUED Dialogue Systems Group 2016)
    
    .. seealso:: CUED Imports/Dependencies: 
    
        import :class:`Policy`
        import :class:`utils.ContextLogger`
    
    .. warning::
            Documentation not done.
    
    
    ************************
    
    '''
    
    import copy
    import os
    import json
    import numpy as np
    import pickle
    import random
    import sys
    import utils
    from utils.Settings import config as cfg
    from utils import ContextLogger, DiaAct, DialogueState
    
    import ontology.FlatOntologyManager as FlatOnt
    # from theano_dialogue.util.tool import *
    
    import tensorflow as tf
    from policy.DRL.replay_bufferVanilla import ReplayBuffer
    from policy.DRL.replay_prioritisedVanilla import ReplayPrioritised
    import policy.DRL.utils as drlutils
    from policy.DRL import bdqn as bbqn
    import policy.Policy
    import policy.SummaryAction
    import policy.BBQNPolicy
    from policy.Policy import TerminalAction, TerminalState
    from policy.feudalRL.DIP_parametrisation import DIP_state, padded_state
    
    logger = utils.ContextLogger.getLogger('')
    
    # --- for flattening the belief --- # 
    domainUtil = FlatOnt.FlatDomainOntology('CamRestaurants')
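    # NOTE: this module-level helper is hard-coded to CamRestaurants; the policy
    # instance below builds its own self.domainUtil for whatever domain it is given.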
    
    
    class FeudalBBQNPolicy(policy.BBQNPolicy.BBQNPolicy):
        '''Derived from :class:`BBQNPolicy`
        '''
        def __init__(self, in_policy_file, out_policy_file, domainString='CamRestaurants', is_training=False,
                     action_names=None, slot=None):
            super(FeudalBBQNPolicy, self).__init__(in_policy_file, out_policy_file, domainString, is_training)
    
            tf.reset_default_graph()
    
            self.domainString = domainString
            self.domainUtil = FlatOnt.FlatDomainOntology(self.domainString)
            self.in_policy_file = in_policy_file
            self.out_policy_file = out_policy_file
            self.is_training = is_training
            self.accum_belief = []
    
            self.prev_state_check = None
    
            self.episode_ave_max_q = []
    
            self.capacity *= 4  # capacity is configured for episode-based methods; scale it up since this policy stores turn-level transitions
            self.slot = slot
    
            # init session
            self.sess = tf.Session()
            with tf.device("/cpu:0"):
    
                np.random.seed(self.randomseed)
                tf.set_random_seed(self.randomseed)
    
                # initialise a replay buffer
                if self.replay_type == 'vanilla':
                    self.episodes[self.domainString] = ReplayBuffer(self.capacity, self.minibatch_size, self.randomseed)
                elif self.replay_type == 'prioritized':
                    self.episodes[self.domainString] = ReplayPrioritised(self.capacity, self.minibatch_size,
                                                                         self.randomseed)
                # replay_buffer = ReplayBuffer(self.capacity, self.randomseed)
                # self.episodes = []
                self.samplecount = 0
                self.episodecount = 0
    
                # construct the models
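                # DIP = Domain-Independent Parametrisation of the belief state
                # (see policy.feudalRL.DIP_parametrisation); 89 is its base feature size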
                self.state_dim = 89  # current DIP state dim
                self.summaryaction = policy.SummaryAction.SummaryAction(domainString)
                self.action_names = action_names
                self.action_dim = len(self.action_names)
                action_bound = len(self.action_names)
                self.stats = [0 for _ in range(self.action_dim)]
                self.stdVar = []
                self.meanVar = []
                self.stdMean = []
                self.meanMean = []
                self.td_error = []
                self.td_errorVar = []
    
                self.target_update_freq = 1
                if cfg.has_option('bbqnpolicy', 'target_update_freq'):
                    self.target_update_freq = cfg.getint('bbqnpolicy', 'target_update_freq')
    
                #feudal params
                self.features = 'dip'
                self.sd_enc_size = 25
                self.si_enc_size = 50
                self.dropout_rate = 0.
                if cfg.has_option('feudalpolicy', 'features'):
                    self.features = cfg.get('feudalpolicy', 'features')
                if cfg.has_option('feudalpolicy', 'sd_enc_size'):
                    self.sd_enc_size = cfg.getint('feudalpolicy', 'sd_enc_size')
                if cfg.has_option('feudalpolicy', 'si_enc_size'):
                    self.si_enc_size = cfg.getint('feudalpolicy', 'si_enc_size')
                if cfg.has_option('feudalpolicy', 'dropout_rate') and self.is_training:
                    self.dropout_rate = cfg.getfloat('feudalpolicy', 'dropout_rate')
                self.actfreq_ds = False
                if cfg.has_option('feudalpolicy', 'actfreq_ds'):
                    self.actfreq_ds = cfg.getboolean('feudalpolicy', 'actfreq_ds')
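                # A minimal sketch of the config sections read above, shown with the
                # in-code defaults that apply when an option is absent:
                #   [bbqnpolicy]
                #   target_update_freq = 1
                #   [feudalpolicy]
                #   features = dip          # one of: dip, learned, rnn
                #   sd_enc_size = 25
                #   si_enc_size = 50
                #   dropout_rate = 0.0      # only read when is_training is set
                #   actfreq_ds = False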
    
                if self.features == 'dip':
                    if self.actfreq_ds:
                        if self.domainString == 'CamRestaurants':
                            self.state_dim += 16
                        elif self.domainString == 'SFRestaurants':
                            self.state_dim += 25
                        elif self.domainString == 'Laptops11':
                            self.state_dim += 40
    
                    self.bbqn = bbqn.DeepQNetwork(self.sess, self.state_dim, self.action_dim, self.learning_rate, self.tau,
                                                  action_bound, self.architecture, self.h1_size, self.h2_size,
                                                  self.n_samples,
                                                  self.minibatch_size, self.sigma_prior, self.n_batches, self.stddev_var_mu,
                                                  self.stddev_var_logsigma, self.mean_log_sigma, self.importance_sampling,
                                                  self.alpha_divergence, self.alpha, self.sigma_eps)
                elif self.features == 'learned' or self.features == 'rnn':
                    si_state_dim = 72
                    if self.actfreq_ds:
                        if self.domainString == 'CamRestaurants':
                            si_state_dim += 16
                        elif self.domainString == 'SFRestaurants':
                            si_state_dim += 25
                        elif self.domainString == 'Laptops11':
                            si_state_dim += 40
                    if self.domainString == 'CamRestaurants':
                        sd_state_dim = 94
                    elif self.domainString == 'SFRestaurants':
                        sd_state_dim = 158
                    elif self.domainString == 'Laptops11':
                        sd_state_dim = 13
                    else:
                        logger.error(
                            'Domain {} not implemented in feudal-BBQN yet'.format(self.domainString))  # just find out the size of sd_state_dim for the new domain
                    if self.features == 'rnn':
                        arch = 'rnn'
                        self.state_dim = si_state_dim + sd_state_dim
                        self.bbqn = bbqn.RNNBBQNetwork(self.sess, si_state_dim, sd_state_dim, self.action_dim, self.learning_rate,
                                                      self.tau, action_bound, arch, self.h1_size, self.h2_size, self.n_samples,
                                                      self.minibatch_size, self.sigma_prior, self.n_batches, self.stddev_var_mu,
                                                      self.stddev_var_logsigma, self.mean_log_sigma, self.importance_sampling,
                                                      self.alpha_divergence, self.alpha, self.sigma_eps, sd_enc_size=self.sd_enc_size,
                                                       si_enc_size=self.si_enc_size, dropout_rate=self.dropout_rate, slot=slot)
                    else:
                        arch = 'vanilla'
                        self.state_dim = si_state_dim + sd_state_dim
                        self.bbqn = bbqn.NNBBQNetwork(self.sess, si_state_dim, sd_state_dim, self.action_dim, self.learning_rate,
                                                      self.tau, action_bound, arch, self.h1_size, self.h2_size, self.n_samples,
                                                      self.minibatch_size, self.sigma_prior, self.n_batches, self.stddev_var_mu,
                                                      self.stddev_var_logsigma, self.mean_log_sigma, self.importance_sampling,
                                                      self.alpha_divergence, self.alpha, self.sigma_eps, sd_enc_size=self.sd_enc_size,
                                                       si_enc_size=self.si_enc_size, dropout_rate=self.dropout_rate, slot=slot)
                else:
                    logger.error('features "{}" not implemented'.format(self.features))
    
    
    
                # when all models are defined, init all variables
                init_op = tf.global_variables_initializer()
                self.sess.run(init_op)
    
                self.loadPolicy(self.in_policy_file)
                print('loaded replay size: ', self.episodes[self.domainString].size())
    
                self.bbqn.update_target_network()
    
        def record(self, reward, domainInControl=None, weight=None, state=None, action=None, exec_mask=None):
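            # Store one turn-level transition for this worker policy. The reward is
            # scaled by 1/20 (cf. finalizeRecord); for prioritised replay the current
            # Q estimate and the discounted max target-Q are also passed to the buffer.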
            if domainInControl is None:
                domainInControl = self.domainString
            if self.actToBeRecorded is None:
                # self.actToBeRecorded = self.lastSystemAction
                self.actToBeRecorded = self.summaryAct
    
            if state is None:
                state = self.prevbelief
            if action is None:
                action = self.actToBeRecorded
    
            cState, cAction = state, action
    
            reward /= 20.0
    
            cur_cState = np.vstack([np.expand_dims(x, 0) for x in [cState]])
            cur_action_q = self.bbqn.predict(cur_cState)
            cur_target_q = self.bbqn.predict_target(cur_cState)
    
            if exec_mask is not None:
                admissible = np.add(cur_target_q, np.array(exec_mask))
            else:
                admissible = cur_target_q
    
            Q_s_t_a_t_ = cur_action_q[0][cAction]
            gamma_Q_s_tplu1_maxa_ = self.gamma * np.max(admissible)
    
            if weight is None:
                if self.replay_type == 'vanilla':
                    self.episodes[domainInControl].record(state=cState, \
                                                          state_ori=state, action=cAction, reward=reward)
                elif self.replay_type == 'prioritized':
                    # heuristically assign 0.0 to Q_s_t_a_t_ and Q_s_tplu1_maxa_, doesn't matter as it is not used
                    if True:
                        # if self.samplecount >= self.capacity:
                        self.episodes[domainInControl].record(state=cState, \
                                                              state_ori=state, action=cAction, reward=reward, \
                                                              Q_s_t_a_t_=Q_s_t_a_t_,
                                                              gamma_Q_s_tplu1_maxa_=gamma_Q_s_tplu1_maxa_, uniform=False)
                    else:
                        self.episodes[domainInControl].record(state=cState, \
                                                              state_ori=state, action=cAction, reward=reward, \
                                                              Q_s_t_a_t_=Q_s_t_a_t_,
                                                              gamma_Q_s_tplu1_maxa_=gamma_Q_s_tplu1_maxa_, uniform=True)
    
            else:
                self.episodes[domainInControl].record(state=cState, state_ori=state, action=cAction, reward=reward,
                                                      ma_weight=weight)
    
            self.actToBeRecorded = None
            self.samplecount += 1
            return
    
        def finalizeRecord(self, reward, domainInControl=None):
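            # Close the episode: append a terminal transition carrying the final
            # (scaled) reward so the replay buffer can mark the episode boundary.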
            if domainInControl is None:
                domainInControl = self.domainString
            if self.episodes[domainInControl] is None:
                logger.warning("record attempted to be finalized for domain where nothing has been recorded before")
                return
    
            # normalising total return to -1~1
            # if reward == 0:
            #    reward = -20.0
            reward /= 20.0
            """
            if reward == 20.0:
                reward = 1.0
            else:
                reward = -0.5
            """
            # reward = float(reward+10.0)/40.0
    
            terminal_state, terminal_action = self.convertStateAction(TerminalState(), TerminalAction())
    
            if self.replay_type == 'vanilla':
                self.episodes[domainInControl].record(state=terminal_state, \
                                                      state_ori=TerminalState(), action=terminal_action, reward=reward,
                                                      terminal=True)
            elif self.replay_type == 'prioritized':
                # heuristically assign 0.0 to Q_s_t_a_t_ and Q_s_tplu1_maxa_, doesn't matter as it is not used
                if True:
                    # if self.samplecount >= self.capacity:
                    self.episodes[domainInControl].record(state=terminal_state, \
                                                          state_ori=TerminalState(), action=terminal_action, reward=reward, \
                                                          Q_s_t_a_t_=0.0, gamma_Q_s_tplu1_maxa_=0.0, uniform=False,
                                                          terminal=True)
                else:
                    self.episodes[domainInControl].record(state=terminal_state, \
                                                          state_ori=TerminalState(), action=terminal_action, reward=reward, \
                                                          Q_s_t_a_t_=0.0, gamma_Q_s_tplu1_maxa_=0.0, uniform=True,
                                                          terminal=True)
    
        def convertStateAction(self, state, action):
            '''
            Convert a belief state and summary action into the flat feature vector stored
            in the replay buffer: terminal states map to a zero DIP vector, otherwise the
            belief is DIP-encoded for the slot named in the action ('general' if none).
            '''
            if isinstance(state, TerminalState):
                return [0] * 89, action
    
            else:
                if self.features == 'learned' or self.features == 'rnn':
                    dip_state = padded_state(state.domainStates[state.currentdomain], self.domainString)
                else:
                    dip_state = DIP_state(state.domainStates[state.currentdomain], self.domainString)
                action_name = self.actions.action_names[action]
                act_slot = 'general'
                for slot in dip_state.slots:
                    if slot in action_name:
                        act_slot = slot
                flat_belief = dip_state.get_beliefStateVec(act_slot)
                self.prev_state_check = flat_belief
    
                return flat_belief, action
    
        def nextAction(self, beliefstate):
            '''
            Select the next action with epsilon-greedy exploration.

            :param beliefstate: flat (DIP) belief state vector
            :returns: (np.ndarray) Q-value vector over summary actions; the feudal
                master policy converts this into the executed action
            '''
    
            if self.exploration_type == 'e-greedy':
                # epsilon greedy
                if self.is_training and utils.Settings.random.rand() < self.epsilon:
                    action_Q = np.random.rand(len(self.action_names))
                else:
                    action_Q = self.bbqn.predict(np.reshape(beliefstate, (1, len(beliefstate))))  # + (1. / (1. + i + j))
    
                    self.episode_ave_max_q.append(np.max(action_Q))
    
            # return the Q vect, the action will be converted in the feudal policy
            return action_Q
    
    
        def train(self):
            '''
            call this function when the episode ends
            '''
    
            if not self.is_training:
                logger.info("Not in training mode")
                return
            else:
                logger.info("Update dqn policy parameters.")
    
            self.episodecount += 1
            logger.info("Sample Num so far: %s" % (self.samplecount))
            logger.info("Episode Num so far: %s" % (self.episodecount))
    
            if self.samplecount >= self.minibatch_size * 10 and self.episodecount % self.training_frequency == 0:
                logger.info('start training...')
    
                s_batch, s_ori_batch, a_batch, r_batch, s2_batch, s2_ori_batch, t_batch, idx_batch, _ = \
                    self.episodes[self.domainString].sample_batch()
    
                s_batch = np.vstack([np.expand_dims(x, 0) for x in s_batch])
                s2_batch = np.vstack([np.expand_dims(x, 0) for x in s2_batch])
    
                a_batch_one_hot = np.eye(self.action_dim, self.action_dim)[a_batch]
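                # query the online and target networks on the successor states through
                # the DIP-specific predict methods, passing the one-hot encoded actions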
                action_q = self.bbqn.predict_dip(s2_batch, a_batch_one_hot)
                target_q = self.bbqn.predict_target_dip(s2_batch, a_batch_one_hot)
                # print 'action Q and target Q:', action_q, target_q
    
                y_i = []
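                # Build one-step TD targets: terminal transitions use the reward alone,
                # otherwise r + gamma * max_a' Q_target(s', a'), with the execution mask
                # suppressing the last action (cf. the TODO below about recording real masks).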
                for k in range(min(self.minibatch_size, self.episodes[self.domainString].size())):
                    Q_bootstrap_label = 0
                    if t_batch[k]:
                        Q_bootstrap_label = r_batch[k]
                    else:
                        if self.q_update == 'single':
                            belief = s2_ori_batch[k]
                            execMask = [0.0] * len(self.action_names)  # TODO: find out how to compute the mask here, or save it when recording the state
                            execMask[-1] = -sys.maxsize
                            action_Q = target_q[k]
                            admissible = np.add(action_Q, np.array(execMask))
                            Q_bootstrap_label = r_batch[k] + self.gamma * np.max(admissible)
    
                    y_i.append(Q_bootstrap_label)
    
                # Update the critic given the targets
                reshaped_yi = np.vstack([np.expand_dims(x, 0) for x in y_i])
    
                predicted_q_value, _, currentLoss, logLikelihood, varFC2, meanFC2, td_error, KL_div = self.bbqn.train(s_batch, a_batch_one_hot, reshaped_yi, self.episodecount)
    
            if self.episodecount % self.target_update_freq == 0:
                self.bbqn.update_target_network()
            if self.episodecount % self.save_step == 0:
                self.savePolicyInc()  # self.out_policy_file)
    
    
    # END OF FILE