Skip to content
Snippets Groups Projects
Commit 7e0da952 authored by Chazzz's avatar Chazzz
Browse files

Added batch normalization and updated model.zip

parent 35ddc487
Branches
No related tags found
No related merge requests found
No preview for this file type
...@@ -27,6 +27,9 @@ class Model: ...@@ -27,6 +27,9 @@ class Model:
self.mustRestore = mustRestore self.mustRestore = mustRestore
self.snapID = 0 self.snapID = 0
# Whether to use normalization over a batch or a population
self.is_train = tf.placeholder(tf.bool, name="is_train");
# input image batch # input image batch
self.inputImgs = tf.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1])) self.inputImgs = tf.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1]))
...@@ -38,6 +41,8 @@ class Model: ...@@ -38,6 +41,8 @@ class Model:
# setup optimizer to train NN # setup optimizer to train NN
self.batchesTrained = 0 self.batchesTrained = 0
self.learningRate = tf.placeholder(tf.float32, shape=[]) self.learningRate = tf.placeholder(tf.float32, shape=[])
self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(self.update_ops):
self.optimizer = tf.train.RMSPropOptimizer(self.learningRate).minimize(self.loss) self.optimizer = tf.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)
# initialize TF # initialize TF
...@@ -59,7 +64,8 @@ class Model: ...@@ -59,7 +64,8 @@ class Model:
for i in range(numLayers): for i in range(numLayers):
kernel = tf.Variable(tf.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1)) kernel = tf.Variable(tf.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1))
conv = tf.nn.conv2d(pool, kernel, padding='SAME', strides=(1,1,1,1)) conv = tf.nn.conv2d(pool, kernel, padding='SAME', strides=(1,1,1,1))
relu = tf.nn.relu(conv) conv_norm = tf.layers.batch_normalization(conv, training=self.is_train)
relu = tf.nn.relu(conv_norm)
pool = tf.nn.max_pool(relu, (1, poolVals[i][0], poolVals[i][1], 1), (1, strideVals[i][0], strideVals[i][1], 1), 'VALID') pool = tf.nn.max_pool(relu, (1, poolVals[i][0], poolVals[i][1], 1), (1, strideVals[i][0], strideVals[i][1], 1), 'VALID')
self.cnnOut4d = pool self.cnnOut4d = pool
...@@ -205,7 +211,7 @@ class Model: ...@@ -205,7 +211,7 @@ class Model:
sparse = self.toSparse(batch.gtTexts) sparse = self.toSparse(batch.gtTexts)
rate = 0.01 if self.batchesTrained < 10 else (0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate rate = 0.01 if self.batchesTrained < 10 else (0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate
evalList = [self.optimizer, self.loss] evalList = [self.optimizer, self.loss]
feedDict = {self.inputImgs : batch.imgs, self.gtTexts : sparse , self.seqLen : [Model.maxTextLen] * numBatchElements, self.learningRate : rate} feedDict = {self.inputImgs : batch.imgs, self.gtTexts : sparse , self.seqLen : [Model.maxTextLen] * numBatchElements, self.learningRate : rate, self.is_train: True}
(_, lossVal) = self.sess.run(evalList, feedDict) (_, lossVal) = self.sess.run(evalList, feedDict)
self.batchesTrained += 1 self.batchesTrained += 1
return lossVal return lossVal
...@@ -217,7 +223,7 @@ class Model: ...@@ -217,7 +223,7 @@ class Model:
# decode, optionally save RNN output # decode, optionally save RNN output
numBatchElements = len(batch.imgs) numBatchElements = len(batch.imgs)
evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else []) evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else [])
feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements} feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict) evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict)
decoded = evalRes[0] decoded = evalRes[0]
texts = self.decoderOutputToText(decoded, numBatchElements) texts = self.decoderOutputToText(decoded, numBatchElements)
...@@ -228,7 +234,7 @@ class Model: ...@@ -228,7 +234,7 @@ class Model:
sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts) sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts)
ctcInput = evalRes[1] ctcInput = evalRes[1]
evalList = self.lossPerElement evalList = self.lossPerElement
feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements} feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
lossVals = self.sess.run(evalList, feedDict) lossVals = self.sess.run(evalList, feedDict)
probs = np.exp(-lossVals) probs = np.exp(-lossVals)
return (texts, probs) return (texts, probs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment