From ca5d4f6ebfb8b947cdcc2dde12a97871fb3fe8d7 Mon Sep 17 00:00:00 2001
From: Carel van Niekerk <vniekerk.carel@gmail.com>
Date: Fri, 8 May 2020 09:39:55 +0200
Subject: [PATCH] Add missing curiosity module

---
 curiosity/README_curiosity              |  58 +++++++++
 curiosity/__init__.py                   |   0
 curiosity/curiosity_module.py           |  77 +++++++++++
 curiosity/model_prediction_curiosity.py | 162 ++++++++++++++++++++++++
 curiosity/pretraining.py                | 156 +++++++++++++++++++++++
 5 files changed, 453 insertions(+)
 create mode 100644 curiosity/README_curiosity
 create mode 100644 curiosity/__init__.py
 create mode 100644 curiosity/curiosity_module.py
 create mode 100644 curiosity/model_prediction_curiosity.py
 create mode 100644 curiosity/pretraining.py

diff --git a/curiosity/README_curiosity b/curiosity/README_curiosity
new file mode 100644
index 0000000..8034385
--- /dev/null
+++ b/curiosity/README_curiosity
@@ -0,0 +1,58 @@
+RL with Curiosity Rewards for DM
+
+by Paula
+==================================================================
+
+The curiosity reward option makes it possible to use the belief-state prediction error as an additional reward
+for policy learning via RL.
+
+The following files are affected by the curiosity reward option:
+ 1) pydial.py
+ 2) utils/Settings.py
+ 3) evaluation/SuccessEvaluator.py
+ 4) policy/ACERpolicy.py or DQNpolicy.py (other policies do not include the curiosity model training option yet)
+
+ Additional files (included in the curiosity directory):
+ 5) model_prediction_curiosity.py
+ 6) curiosity_module.py
+ 7) pretraining.py
+
+Configuration:
+The curiosityreward option is set in the evaluation section of the configuration file.
+Curiosity rewards are used when curiosityreward = True. This option overrides any epsilon-greedy exploration settings:
+the policy is then used greedily at all times, with no random exploration.
+The feature size of the belief-state feature encoding can be chosen via feat_size (it defaults to 200 if not specified
+in the configuration file). #Todo: option to choose layer2 size?
+The name of the pre-trained model to be used is specified under the variable model_name in the config file.
+Always make sure that the pre-trained model has the same feature size as the model to be trained,
+otherwise the dimensions will not match when the model is read in.
+The reward scale defaults to 1, which was found to work best for environments 3 and 4.
+The reward scale can be set using the variable rew_scale in the configuration.
+
+Example from configuration file:
+
+    ###### Evaluation parameters ######
+    [eval]
+    rewardvenuerecommended = 0
+    penaliseallturns = True
+    wrongvenuepenalty = 0
+    notmentionedvaluepenalty = 0
+    successmeasure = objective
+    successreward = 20
+    curiosityreward = True
+    feat_size = 200
+    rew_scale = 1
+    model_name = trained_curiosityacer-shuffle22_feat200
+    pre_trg = False
+
+Pre-training:
+1) Pre-training data collection is done by setting pre_trg = True in the config file.
+Note on pre-training data collection: only about 100 dialogues are needed for pre-training,
+so there is no need to run training for longer.
+The dialogues for pre-training should be simple and do not have to match the policy used later on.
+Furthermore, the pre-training dialogues do not have to be successful.
+2) To run pre-training and build an initial curiosity model, the script pretraining.py is used.
+Settings for pre-training are changed directly in the script. Before running, make sure the pre-training data file
+names and model_name are specified.
+After running, the new model can be used by specifying model_name in the config file when training new policies.
+
diff --git a/curiosity/__init__.py b/curiosity/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/curiosity/curiosity_module.py b/curiosity/curiosity_module.py
new file mode 100644
index 0000000..c0a358f
--- /dev/null
+++ b/curiosity/curiosity_module.py
@@ -0,0 +1,77 @@
+###############################################################################
+# idea adapted from:
+# Deepak Pathak, Pulkit Agrawal, Alexei A. Efros, Trevor Darrell
+# University of California, Berkeley
+# Curiosity-driven Exploration by Self-supervised Prediction
+
+# added by Paula
+###############################################################################
+
+import numpy as np
+import os
+import tensorflow as tf
+
+from curiosity import model_prediction_curiosity as mpc
+from utils import Settings
+
+
+class Curious(object):
+    def __init__(self):
+        tf.reset_default_graph()
+        self.learning_rate = 0.001
+        self.forward_loss_wt = 0.2
+        self.feat_size = 200
+        self.num_actions = 16
+        self.num_belief_states = 268
+        self.layer2 = 200
+
+        if Settings.config.has_option("eval", "feat_size"):
+            self.feat_size = Settings.config.getint("eval", "feat_size")
+
+        with tf.variable_scope('curiosity', reuse=tf.AUTO_REUSE):
+            self.predictor = mpc.StateActionPredictor(self.num_belief_states, self.num_actions,
+                                                      feature_size=self.feat_size, layer2=self.layer2)
+
+            self.predloss = self.predictor.invloss * (1 - self.forward_loss_wt) + \
+                            self.predictor.forwardloss * self.forward_loss_wt
+
+            self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
+            self.optimize = self.optimizer.minimize(self.predloss)
+            # self.optimize = self.optimizer.minimize(self.predictor.forwardloss)  # when no feature encoding is used!
+ self.cnt = 1 + + self.sess2 = tf.Session() + self.sess2.run(tf.global_variables_initializer()) + all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) + self.saver = tf.train.Saver(var_list=[v for v in all_variables if "Variab" not in v.name and "beta" not in v.name]) + + def training(self, state_vec, prev_state_vec, action_1hot): + _, predictionloss = self.sess2.run([self.optimize, self.predloss], + feed_dict={self.predictor.s1: prev_state_vec, + self.predictor.s2: state_vec, + self.predictor.asample: action_1hot}) + return predictionloss + + def reward(self, s1, s2, asample): + error = self.sess2.run(self.predictor.forwardloss, + {self.predictor.s1: [s1], self.predictor.s2: [s2], self.predictor.asample: [asample]}) + return error + + def inv_loss(self, s1, s2, asample): + predloss, invloss = self.sess2.run([self.predloss, self.predictor.invloss], + {self.predictor.s1: [s1], self.predictor.s2: [s2], self.predictor.asample: [asample]}) + return predloss, invloss + + def predictedstate(self, s1, s2, asample): + pred, orig = self.sess2.run([self.predictor.predstate, self.predictor.origstate], + {self.predictor.s1: [s1], self.predictor.s2: [s2], + self.predictor.asample: [asample]}) + return pred, orig + + def load_curiosity(self, load_filename): + self.saver.restore(self.sess2, load_filename) + print('Curiosity model has successfully loaded.') + + def save_ICM(self, save_filename): + self.saver.save(self.sess2, save_filename) + print('Curiosity model saved.') \ No newline at end of file diff --git a/curiosity/model_prediction_curiosity.py b/curiosity/model_prediction_curiosity.py new file mode 100644 index 0000000..0126eab --- /dev/null +++ b/curiosity/model_prediction_curiosity.py @@ -0,0 +1,162 @@ +'''Copyright (c) 2017 Deepak Pathak +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +-------------------------------------------------------------------------------- +Original openai License: +-------------------------------------------------------------------------------- +MIT License + +Copyright (c) 2016 openai + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.''' + +############################################################################### +# adapted from: +# Deepak Pathak, Pulkit Agrawal, Alexei A. Efros, Trevor Darrell +# University of California, Berkeley +# Curiosity-driven Exploration by Self-supervised Prediction + +# added by Paula +############################################################################### + + +import numpy as np +import tensorflow as tf +import tensorflow.contrib.rnn as rnn +import os + + +def normalized_columns_initializer(std=1.0): + def _initializer(shape, dtype=None, partition_info=None): + out = np.random.randn(*shape).astype(np.float32) + out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + return tf.constant(out) + return _initializer + + +def flatten(x): + return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) + + +def cosineLoss(A, B, name): + ''' A, B : (BatchSize, d) ''' + dotprod = tf.reduce_sum(tf.multiply(tf.nn.l2_normalize(A, 1), tf.nn.l2_normalize(B, 1)), 1) + loss = 1-tf.reduce_mean(dotprod, name=name) + return loss + + +def linear(x, size, name, initializer=None, bias_init=0): + with tf.variable_scope(name, reuse=tf.AUTO_REUSE): + w = tf.get_variable("w", [x.get_shape()[1], size], initializer=initializer) # error in second turn, reuse variable? + b = tf.get_variable("b", [size], initializer=tf.constant_initializer(bias_init)) #changed from: name+ "/b" + # initialized now to fix error + return tf.matmul(x, w) + b + + +def inverse_pydialHead(x, final_shape): # does nothing so far! 
+    '''
+    input: [None, 268]; output: [None, 268] (the input is currently returned unchanged);
+    '''
+    # print('Using inverse-pydial head design')
+    # bs = tf.shape(x)[0]
+    # print(x)
+    return x
+
+
+def pydialHead(x, layer2):  # todo: 200 is the layer size; make it a variable so it can be changed easily, e.g. via feat_size
+    '''
+    input: [None, 268]; output: [None, 200];
+    '''
+    x = tf.nn.elu(linear(x, 200, 'fc', normalized_columns_initializer(0.01)))
+    # print(x.get_shape())
+    # x = flatten(x)
+    # print(x.get_shape())
+    # x = tf.nn.elu(x)
+    return x
+
+
+class StateActionPredictor(object):
+    def __init__(self, ob_space, ac_space, designHead='pydial', feature_size=200, layer2=200):
+        # input: s1, s2: [None, h, w, ch] (usually ch=1 or 4) / pydial: [None, size]
+        # asample: 1-hot encoding of the action sampled from the policy: [None, ac_space]
+
+        self.layer2 = layer2
+        if designHead == 'pydial':
+            input_shape = [None, ob_space]
+        else:
+            input_shape = [None] + list(ob_space)
+
+        self.s1 = phi1 = tf.placeholder(tf.float32, input_shape)
+        self.s2 = phi2 = tf.placeholder(tf.float32, input_shape)
+        self.asample = asample = tf.placeholder(tf.float32, [None, ac_space])
+
+        # feature encoding: phi1, phi2: [None, LEN]
+        size = feature_size  # 268 for the full belief state
+        if designHead == 'pydial':
+            phi1 = pydialHead(phi1, self.layer2)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                phi2 = pydialHead(phi2, self.layer2)
+        else:
+            print('So far "pydial" is the only available design head. Please check your configurations.')
+
+        # inverse model: g(phi1,phi2) -> a_inv: [None, ac_space]
+        g = tf.concat([phi1, phi2], 1)  # tf.concat(values, axis): axis is now the second argument
+        g = tf.nn.relu(linear(g, size, "g1", normalized_columns_initializer(0.01)))
+        aindex = tf.argmax(asample, axis=1)  # aindex: [batch_size,]
+        logits = linear(g, ac_space, "glast", normalized_columns_initializer(0.01))
+        self.invloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=aindex), name="invloss")
+        self.ainvprobs = tf.nn.softmax(logits, axis=-1)
+
+        # forward model: f(phi1,asample) -> phi2
+        # Note: no backprop to asample of the policy: it is treated as fixed for predictor training
+        f = tf.concat([phi1, asample], 1)
+        f = tf.nn.relu(linear(f, size, "f1", normalized_columns_initializer(0.01)))
+        f = linear(f, phi1.get_shape()[1].value, "flast", normalized_columns_initializer(0.01))
+        # self.forwardloss = 0.5 * tf.reduce_mean(tf.square(tf.subtract(f, phi2)), name='forwardloss')
+
+        # self.forwardloss = 0.5 * tf.reduce_mean(tf.sqrt(tf.abs(tf.subtract(f, phi2))), name='forwardloss')
+        self.forwardloss = cosineLoss(f, phi2, name='forwardloss')
+        # self.forwardloss = self.forwardloss * 268.0  # lenFeatures=268. Factored out to make hyperparams not depend on it.
+
+        # prediction and original
+        self.predstate = f
+        self.origstate = phi2
+
diff --git a/curiosity/pretraining.py b/curiosity/pretraining.py
new file mode 100644
index 0000000..f6cd658
--- /dev/null
+++ b/curiosity/pretraining.py
@@ -0,0 +1,156 @@
+###############################################################################
+# PyDial: Multi-domain Statistical Spoken Dialogue System Software
+###############################################################################
+#
+# Copyright 2015 - 2019
+# Cambridge University Engineering Department Dialogue Systems Group
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+###############################################################################
+
+
+'''
+    This script pre-trains a belief-state prediction model.
+    The resulting model can then be used to provide belief-state prediction errors as curiosity rewards.
+'''
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import tensorflow as tf
+
+from curiosity import model_prediction_curiosity as mpc
+
+# settings  **before running: make sure model_name is specified so existing data is not accidentally overwritten**
+model_name = ''  # name for the new model: specify!
+num_actions = 16
+num_belief_states = 268
+num_iterations = 3
+learning_rate = 0.001
+forward_loss_wt = 0.2
+feature_size = 200
+# file names of the pre-training data: fill out!
+action_pre_trg = ''
+state_pre_trg = ''
+prevstate_pre_trg = ''
+
+
+# read actions and turns from the pre-training data file
+def read_data1(filename):
+    sys_act = []
+    turn = []
+    with open(filename, 'r') as d:
+        for line in d:
+            info = line.split(' ')
+            turn.append(int(info[1]))
+            sys_act.append(int(info[3]))
+    return turn, sys_act
+
+
+# read state and prev_state from the pre-training data files
+def read_data2(filename_ps, filename_s):
+    state = []
+    prev_state = []
+    with open(filename_ps, 'r') as d:
+        for line in d:
+            info = line.split(' ')
+            prev_state.append(info)
+    with open(filename_s, 'r') as d2:
+        for line in d2:
+            info = line.split(' ')
+            state.append(info)
+    return state, prev_state
+
+
+def unison_shuffled_copies(vec1, vec2, vec3, vec4):
+    assert len(vec1) == len(vec4)
+    p = np.random.permutation(len(vec1))
+    return np.array(vec1)[p], np.array(vec2)[p], np.array(vec3)[p], np.array(vec4)[p]
+
+
+with tf.variable_scope('curiosity'):
+    predictor = mpc.StateActionPredictor(num_belief_states, num_actions, designHead='pydial', feature_size=feature_size)
+    predloss = predictor.invloss * (1 - forward_loss_wt) + predictor.forwardloss * forward_loss_wt
+
+optimizer = tf.train.AdamOptimizer(learning_rate)
+optimize = optimizer.minimize(predloss)
+
+sess = tf.Session()
+sess.run(tf.global_variables_initializer())
+saver = tf.train.Saver()
+
+# read data from files
+t, a = read_data1(action_pre_trg)
+a = np.eye(16, 16)[a]  # convert to one-hot
+s, _s = read_data2(prevstate_pre_trg, state_pre_trg)
+
+# shuffle vectors
+t, a, s, _s = unison_shuffled_copies(t, a, s, _s)
+
+batch_num = len(t) // 64  # integer division so the result can be used with range()
+
+# initialize
+loss = []
+inverseloss = []
+forwardloss = []
+
+# # check if the prev state vec is correct:
+# if s[:-1] == _s[1:]:
+#     print('works')
+
+if not os.path.exists('_curiosity_model/pretrg_model/'):
+    os.makedirs('_curiosity_model/pretrg_model/')
+
+# # uncomment to train a pre-trained model further
+# saver.restore(sess, "_curiosity_model/pretrg_model/" + model_name)
+# print("Successfully loaded: _curiosity_model/pretrg_model/" + model_name)
+
+for i in range(num_iterations):
+    for batch in range(batch_num):
+        # select batch for training
+        prev_state_vec = _s[batch * 64:(batch + 1) * 64]
+        state_vec = s[batch * 64:(batch + 1) * 64]
+        action_1hot = a[batch * 64:(batch + 1) * 64]
+        _, predictionloss, forloss, invloss = sess.run([optimize, predloss, predictor.forwardloss,
+                                                        predictor.invloss],
+                                                       feed_dict={predictor.s1: prev_state_vec, predictor.s2: state_vec,
+                                                                  predictor.asample: action_1hot})
+        # if batch % 5 == 0:
+        #     print(predictionloss)
+        loss.append(predictionloss)
+        inverseloss.append(invloss)
+        forwardloss.append(forloss)
+
+    t, a, s, _s = unison_shuffled_copies(t, a, s, _s)  # shuffle vectors
+
+saver.save(sess, '_curiosity_model/pretrg_model/trained_curiosity_' + model_name + '_' + str(feature_size))
+
+plt.plot(loss, label='prediction_loss')
+plt.plot(inverseloss, label='inverse_loss')
+plt.plot(forwardloss, label='forward_loss')
+plt.legend()
+plt.ylabel('Prediction error / loss')
+plt.xlabel('number of batches')
+plt.savefig('_plots/pretraining_loss_' + model_name + '_' + str(feature_size) + '.png', bbox_inches='tight')
+
+
+# # uncomment if needed for experiments
+# def curiosity_reward(s1, s2, asample):
+#     error = sess.run(predictor.forwardloss,
+#                      {predictor.s1: [s1], predictor.s2: [s2], predictor.asample: [asample]})
+#     return error
+#
+# bonus = curiosity_reward(_s[13], s[13], a[13])
+# print(bonus)
-- 
GitLab
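
For reference, a minimal sketch of how the pre-trained model might be queried during policy learning: the belief-state prediction error for a turn is scaled by rew_scale and added to the per-turn reward. The actual integration lives in the policy code listed in the README (ACERpolicy.py / DQNpolicy.py) and is not shown in this patch, so the checkpoint path, the variable names and the augmented_reward helper below are illustrative assumptions. A loaded PyDial config with the [eval] options from the README is also assumed, since Curious() reads feat_size from utils.Settings.

    # Hypothetical usage sketch only; names and paths are placeholders.
    from curiosity.curiosity_module import Curious

    curiosity = Curious()
    # the checkpoint must match the feat_size used during pre-training (see README)
    curiosity.load_curiosity('_curiosity_model/pretrg_model/trained_curiosity_mymodel_200')

    rew_scale = 1.0  # [eval] rew_scale from the config


    def augmented_reward(env_reward, prev_belief_vec, belief_vec, action_1hot):
        # add the scaled belief-state prediction error to the environment reward for this turn
        bonus = curiosity.reward(prev_belief_vec, belief_vec, action_1hot)
        return env_reward + rew_scale * bonus

    # The predictor itself is updated from batches of the same transitions, e.g.:
    # pred_loss = curiosity.training(state_batch, prev_state_batch, action_batch)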