Commit ca5d4f6e authored by Carel van Niekerk

Add missing curiosity module

parent 82d071ea
Pipeline #40108 passed
RL with Curiosity Rewards for DM
by Paula
==================================================================
The curiosity reward option enables using the belief-state prediction error as an additional reward for
policy learning via RL (a short sketch of the shaped reward follows the file list below).
The following files are affected by the curiosity reward option:
1) pydial.py
2) utils/Settings.py
3) evaluation/SuccessEvaluator.py
4) policy/ACERpolicy.py or DQNpolicy.py (other policies do not include curiosity model training option yet)
additional files (included in the curiosity directory):
5) model_prediction_curiosity.py
6) curiosity_module.py
7) pretraining.py
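A minimal sketch of the shaped reward (the variable names are illustrative; "curiosity" stands for the Curious model
defined in curiosity_module.py, and the actual integration lives in the policy files listed above):
    bonus = curiosity.reward(prev_belief_vec, belief_vec, action_one_hot)  # forward-model prediction error
    total_reward = task_reward + rew_scale * bonus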
Configuration:
The curiosityreward option is set in the evaluation section of the configuration file.
Curiosity rewards are used when curiosityreward = True. This option overrides any epsilon-greedy exploration settings;
the policy then acts greedily at all times, with no random exploration.
The size of the belief-state feature encoding can be set via feat_size (by default it is 200 if not specified in
the configuration file). # Todo: option to choose layer2 size?
The name of the pre-trained model to be used is specified via the variable model_name in the config file.
Always make sure that the pre-trained model has the same feature size as the model to be trained,
otherwise the dimensions will not match when the model is read in.
The reward scale defaults to 1, which was found to work best for environments 3 and 4.
It can be set via the variable rew_scale in the configuration.
Example from configuration file:
###### Evaluation parameters ######
[eval]
rewardvenuerecommended = 0
penaliseallturns = True
wrongvenuepenalty = 0
notmentionedvaluepenalty = 0
successmeasure = objective
successreward = 20
curiosityreward = True
feat_size = 200
rew_scale = 1
model_name = trained_curiosityacer-shuffle22_feat200
pre_trg = False
Pre-training:
1) Pre-training data collection is done by setting pre_trg = True in the config file.
Note on pre-training data collection: only about 100 dialogues are needed for pre-training,
so there is no need to run training for longer.
The dialogues used for pre-training should be simple and do not have to match the policy used later on.
Furthermore, the pre-training dialogues do not have to be successful.
2) To run pre-training and build an initial curiosity model, the script pretraining.py is used.
Settings for pre-training are changed directly in the script. Before running, make sure the pre-training data file names
and model_name are specified (a filled-in example follows below).
Once it has run, the new model can be used by specifying model_name in the config file when training new policies.
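For example, the settings block at the top of pretraining.py might be filled in as follows (the model name and data
file paths are placeholders, not files shipped with the module):

    model_name = 'acer_env3'
    feature_size = 200
    action_pre_trg = 'pretrg_data_actions.txt'
    state_pre_trg = 'pretrg_data_states.txt'
    prevstate_pre_trg = 'pretrg_data_prev_states.txt'

Running python pretraining.py then saves a checkpoint under
_curiosity_model/pretrg_model/trained_curiosity_<model_name>_<feature_size>; this checkpoint is what model_name
in the config file should refer to when training new policies.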
###############################################################################
# idea adapted from:
# Deepak Pathak, Pulkit Agrawal, Alexei A. Efros, Trevor Darrell
# University of California, Berkeley
# Curiosity-driven Exploration by Self-supervised Prediction
# added by Paula
###############################################################################
import numpy as np
import os
import tensorflow as tf

from curiosity import model_prediction_curiosity as mpc
from utils import Settings


class Curious(object):
    def __init__(self):
        tf.reset_default_graph()
        self.learning_rate = 0.001
        self.forward_loss_wt = 0.2
        self.feat_size = 200
        self.num_actions = 16
        self.num_belief_states = 268
        self.layer2 = 200
        if Settings.config.has_option("eval", "feat_size"):
            self.feat_size = Settings.config.getint("eval", "feat_size")

        with tf.variable_scope('curiosity', reuse=tf.AUTO_REUSE):
            self.predictor = mpc.StateActionPredictor(self.num_belief_states, self.num_actions,
                                                      feature_size=self.feat_size, layer2=self.layer2)
            # combined ICM objective: weighted sum of inverse- and forward-model losses
            self.predloss = self.predictor.invloss * (1 - self.forward_loss_wt) + \
                            self.predictor.forwardloss * self.forward_loss_wt
            self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
            self.optimize = self.optimizer.minimize(self.predloss)
            # self.optimize = self.optimizer.minimize(self.predictor.forwardloss)  # when no feature encoding is used!

        self.cnt = 1
        self.sess2 = tf.Session()
        self.sess2.run(tf.global_variables_initializer())
        all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
        # skip variables whose names contain "Variab" or "beta" (e.g. Adam's beta accumulators)
        self.saver = tf.train.Saver(
            var_list=[v for v in all_variables if "Variab" not in v.name and "beta" not in v.name])

    def training(self, state_vec, prev_state_vec, action_1hot):
        _, predictionloss = self.sess2.run([self.optimize, self.predloss],
                                           feed_dict={self.predictor.s1: prev_state_vec,
                                                      self.predictor.s2: state_vec,
                                                      self.predictor.asample: action_1hot})
        return predictionloss

    def reward(self, s1, s2, asample):
        error = self.sess2.run(self.predictor.forwardloss,
                               {self.predictor.s1: [s1], self.predictor.s2: [s2],
                                self.predictor.asample: [asample]})
        return error

    def inv_loss(self, s1, s2, asample):
        predloss, invloss = self.sess2.run([self.predloss, self.predictor.invloss],
                                           {self.predictor.s1: [s1], self.predictor.s2: [s2],
                                            self.predictor.asample: [asample]})
        return predloss, invloss

    def predictedstate(self, s1, s2, asample):
        pred, orig = self.sess2.run([self.predictor.predstate, self.predictor.origstate],
                                    {self.predictor.s1: [s1], self.predictor.s2: [s2],
                                     self.predictor.asample: [asample]})
        return pred, orig

    def load_curiosity(self, load_filename):
        self.saver.restore(self.sess2, load_filename)
        print('Curiosity model successfully loaded.')

    def save_ICM(self, save_filename):
        self.saver.save(self.sess2, save_filename)
        print('Curiosity model saved.')
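A minimal usage sketch of the Curious class above (a sketch only: it assumes the class is importable as
curiosity.curiosity_module, that Settings.config has already been loaded from a config file, and that a pre-trained
checkpoint with the default feature size exists under _curiosity_model/pretrg_model/; the vectors below are random
placeholders):

    import numpy as np
    from curiosity.curiosity_module import Curious

    curiosity = Curious()
    curiosity.load_curiosity('_curiosity_model/pretrg_model/trained_curiosity_mymodel_200')  # placeholder checkpoint

    # illustrative transition: 268-dim belief-state vectors and a one-hot action over 16 summary actions
    prev_belief = np.random.rand(268)
    belief = np.random.rand(268)
    action_one_hot = np.eye(16)[3]

    task_reward, rew_scale = -1.0, 1.0  # placeholders; rew_scale comes from the [eval] config section
    bonus = curiosity.reward(prev_belief, belief, action_one_hot)  # forward-model prediction error
    shaped_reward = task_reward + rew_scale * bonus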
'''Copyright (c) 2017 Deepak Pathak
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
Original openai License:
--------------------------------------------------------------------------------
MIT License
Copyright (c) 2016 openai
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.'''
###############################################################################
# adapted from:
# Deepak Pathak, Pulkit Agrawal, Alexei A. Efros, Trevor Darrell
# University of California, Berkeley
# Curiosity-driven Exploration by Self-supervised Prediction
# added by Paula
###############################################################################
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import os


def normalized_columns_initializer(std=1.0):
    def _initializer(shape, dtype=None, partition_info=None):
        out = np.random.randn(*shape).astype(np.float32)
        out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
        return tf.constant(out)
    return _initializer


def flatten(x):
    return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])


def cosineLoss(A, B, name):
    ''' A, B : (BatchSize, d) '''
    dotprod = tf.reduce_sum(tf.multiply(tf.nn.l2_normalize(A, 1), tf.nn.l2_normalize(B, 1)), 1)
    loss = 1 - tf.reduce_mean(dotprod, name=name)
    return loss
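# Illustrative note: cosineLoss returns 1 minus the mean cosine similarity between
# corresponding rows of A and B, so identical feature vectors give ~0, orthogonal
# vectors ~1 and opposite vectors ~2. Quick sanity check:
#   with tf.Session() as _s:
#       _a = tf.constant([[1.0, 0.0]]); _b = tf.constant([[0.0, 1.0]])
#       _s.run(cosineLoss(_a, _a, 'same'))  # ~0.0
#       _s.run(cosineLoss(_a, _b, 'orth'))  # ~1.0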
def linear(x, size, name, initializer=None, bias_init=0):
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        # AUTO_REUSE so the same "w"/"b" variables are reused when the head is built a second time
        w = tf.get_variable("w", [x.get_shape()[1], size], initializer=initializer)
        b = tf.get_variable("b", [size], initializer=tf.constant_initializer(bias_init))
    return tf.matmul(x, w) + b


def inverse_pydialHead(x, final_shape):
    '''
    Placeholder inverse head: currently returns the input unchanged.
    input: [None, 268]; output: [None, 268]
    '''
    return x


def pydialHead(x, layer2):
    '''
    Belief-state feature encoding: input [None, 268]; output [None, layer2].
    '''
    # layer2 (default 200) is the size of the encoding layer
    x = tf.nn.elu(linear(x, layer2, 'fc', normalized_columns_initializer(0.01)))
    return x
class StateActionPredictor(object):
    def __init__(self, ob_space, ac_space, designHead='pydial', feature_size=200, layer2=200):
        # input: s1, s2: [None, h, w, ch] (usually ch=1 or 4); pydial: [None, ob_space]
        # asample: 1-hot encoding of the sampled action from the policy: [None, ac_space]
        self.layer2 = layer2
        if designHead == 'pydial':
            input_shape = [None, ob_space]
        else:
            input_shape = [None] + list(ob_space)

        self.s1 = phi1 = tf.placeholder(tf.float32, input_shape)
        self.s2 = phi2 = tf.placeholder(tf.float32, input_shape)
        self.asample = asample = tf.placeholder(tf.float32, [None, ac_space])

        # feature encoding: phi1, phi2: [None, layer2]
        size = feature_size  # 268 for the full belief state
        if designHead == 'pydial':
            phi1 = pydialHead(phi1, self.layer2)
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                phi2 = pydialHead(phi2, self.layer2)
        else:
            print('So far "pydial" is the only available design head. Please check your configuration.')

        # inverse model: g(phi1, phi2) -> a_inv: [None, ac_space]
        g = tf.concat([phi1, phi2], 1)
        g = tf.nn.relu(linear(g, size, "g1", normalized_columns_initializer(0.01)))
        aindex = tf.argmax(asample, axis=1)  # aindex: [batch_size,]
        logits = linear(g, ac_space, "glast", normalized_columns_initializer(0.01))
        self.invloss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=aindex),
            name="invloss")
        self.ainvprobs = tf.nn.softmax(logits, axis=-1)

        # forward model: f(phi1, asample) -> phi2
        # note: no backprop to asample of the policy: it is treated as fixed for predictor training
        f = tf.concat([phi1, asample], 1)
        f = tf.nn.relu(linear(f, size, "f1", normalized_columns_initializer(0.01)))
        f = linear(f, phi1.get_shape()[1].value, "flast", normalized_columns_initializer(0.01))
        # self.forwardloss = 0.5 * tf.reduce_mean(tf.square(tf.subtract(f, phi2)), name='forwardloss')
        # self.forwardloss = 0.5 * tf.reduce_mean(tf.sqrt(tf.abs(tf.subtract(f, phi2))), name='forwardloss')
        self.forwardloss = cosineLoss(f, phi2, name='forwardloss')
        # self.forwardloss = self.forwardloss * 268.0  # lenFeatures=268; factored out so hyperparameters do not depend on it

        # predicted and original (encoded) next belief state
        self.predstate = f
        self.origstate = phi2
###############################################################################
# PyDial: Multi-domain Statistical Spoken Dialogue System Software
###############################################################################
#
# Copyright 2015 - 2019
# Cambridge University Engineering Department Dialogue Systems Group
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
###############################################################################
'''
This script pre-trains a belief-state prediction model.
The pre-trained model can then be used to provide the belief-state prediction error as a curiosity reward.
'''
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from curiosity import model_prediction_curiosity as mpc

# settings **before running: make sure model_name is specified to not accidentally overwrite data**
model_name = ''  # name for new model: specify!
num_actions = 16
num_belief_states = 268
num_iterations = 3
learning_rate = 0.001
forward_loss_wt = 0.2
feature_size = 200

# file names pre-trg data: fill out!
action_pre_trg = ''
state_pre_trg = ''
prevstate_pre_trg = ''


# read actions and turns from pretrg data file
def read_data1(filename):
    sys_act = []
    turn = []
    with open(filename, 'r') as d:
        for line in d:
            info = line.split(' ')
            turn.append(int(info[1]))
            sys_act.append(int(info[3]))
    return turn, sys_act


# read state and prev_state from pretrg data file
def read_data2(filename_ps, filename_s):
    state = []
    prev_state = []
    with open(filename_ps, 'r') as d:
        for line in d:
            info = line.split(' ')
            prev_state.append(info)
    with open(filename_s, 'r') as d2:
        for line in d2:
            info = line.split(' ')
            state.append(info)
    return state, prev_state


def unison_shuffled_copies(vec1, vec2, vec3, vec4):
    assert len(vec1) == len(vec4)
    p = np.random.permutation(len(vec1))
    return np.array(vec1)[p], np.array(vec2)[p], np.array(vec3)[p], np.array(vec4)[p]
with tf.variable_scope('curiosity'):
    predictor = mpc.StateActionPredictor(num_belief_states, num_actions, designHead='pydial',
                                         feature_size=feature_size)
    predloss = predictor.invloss * (1 - forward_loss_wt) + predictor.forwardloss * forward_loss_wt
    optimizer = tf.train.AdamOptimizer(learning_rate)
    optimize = optimizer.minimize(predloss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

# read data from files
t, a = read_data1(action_pre_trg)
a = np.eye(num_actions, num_actions)[a]  # convert actions to one-hot
s, _s = read_data2(prevstate_pre_trg, state_pre_trg)

# shuffle vectors
t, a, s, _s = unison_shuffled_copies(t, a, s, _s)
batch_num = len(t) // 64  # number of 64-sample batches (integer division)

# initialize
loss = []
inverseloss = []
forwardloss = []

# # check if prev state vec is correct:
# if s[:-1] == _s[1:]:
#     print('works')

if not os.path.exists('_curiosity_model/pretrg_model/'):
    os.makedirs('_curiosity_model/pretrg_model/')

# # uncomment to train a pre-trained model further
# saver.restore(sess, "_curiosity_model/pretrg_model/" + model_name)
# print("Successfully loaded: _curiosity_model/pretrg_model/" + model_name)

for i in range(num_iterations):
    for batch in range(batch_num):
        # select batch for training
        prev_state_vec = _s[batch * 64:(batch + 1) * 64]
        state_vec = s[batch * 64:(batch + 1) * 64]
        action_1hot = a[batch * 64:(batch + 1) * 64]
        _, predictionloss, forloss, invloss = sess.run(
            [optimize, predloss, predictor.forwardloss, predictor.invloss],
            feed_dict={predictor.s1: prev_state_vec, predictor.s2: state_vec,
                       predictor.asample: action_1hot})
        # if batch % 5 == 0:
        #     print(predictionloss)
        loss.append(predictionloss)
        inverseloss.append(invloss)
        forwardloss.append(forloss)
    t, a, s, _s = unison_shuffled_copies(t, a, s, _s)  # re-shuffle after each pass

saver.save(sess, '_curiosity_model/pretrg_model/trained_curiosity_' + model_name + '_' + str(feature_size))
plt.plot(loss, label='prediction_loss')
plt.plot(inverseloss, label='inverse_loss')
plt.plot(forwardloss, label='forward_loss')
plt.legend()
plt.ylabel('Prediction error/ Loss')
plt.xlabel('number of batches')
plt.savefig('_plots/pretraining_loss_' + model_name + '_' + str(feature_size) + '.png', bbox_inches='tight')

# # uncomment if needed for experiments
# def curiosity_reward(s1, s2, asample):
#     error = sess.run(predictor.forwardloss,
#                      {predictor.s1: [s1], predictor.s2: [s2], predictor.asample: [asample]})
#     return error
#
# bonus = curiosity_reward(_s[13], s[13], a[13])
# print(bonus)