Skip to content
Snippets Groups Projects
Commit 2328e221 authored by Harald Scheidl
Browse files

Merge branch 'Chazzz-master'

parents 35ddc487 b56ed6b4
No related branches found
No related tags found
No related merge requests found
No preview for this file type
...@@ -27,6 +27,9 @@ class Model: ...@@ -27,6 +27,9 @@ class Model:
self.mustRestore = mustRestore self.mustRestore = mustRestore
self.snapID = 0 self.snapID = 0
# Whether to use normalization over a batch or a population
self.is_train = tf.placeholder(tf.bool, name="is_train");
# input image batch # input image batch
self.inputImgs = tf.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1])) self.inputImgs = tf.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1]))
...@@ -38,6 +41,8 @@ class Model: ...@@ -38,6 +41,8 @@ class Model:
# setup optimizer to train NN # setup optimizer to train NN
self.batchesTrained = 0 self.batchesTrained = 0
self.learningRate = tf.placeholder(tf.float32, shape=[]) self.learningRate = tf.placeholder(tf.float32, shape=[])
self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(self.update_ops):
self.optimizer = tf.train.RMSPropOptimizer(self.learningRate).minimize(self.loss) self.optimizer = tf.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)
# initialize TF # initialize TF
...@@ -59,7 +64,8 @@ class Model: ...@@ -59,7 +64,8 @@ class Model:
for i in range(numLayers): for i in range(numLayers):
kernel = tf.Variable(tf.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1)) kernel = tf.Variable(tf.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1))
conv = tf.nn.conv2d(pool, kernel, padding='SAME', strides=(1,1,1,1)) conv = tf.nn.conv2d(pool, kernel, padding='SAME', strides=(1,1,1,1))
relu = tf.nn.relu(conv) conv_norm = tf.layers.batch_normalization(conv, training=self.is_train)
relu = tf.nn.relu(conv_norm)
pool = tf.nn.max_pool(relu, (1, poolVals[i][0], poolVals[i][1], 1), (1, strideVals[i][0], strideVals[i][1], 1), 'VALID') pool = tf.nn.max_pool(relu, (1, poolVals[i][0], poolVals[i][1], 1), (1, strideVals[i][0], strideVals[i][1], 1), 'VALID')
self.cnnOut4d = pool self.cnnOut4d = pool
...@@ -205,19 +211,19 @@ class Model: ...@@ -205,19 +211,19 @@ class Model:
sparse = self.toSparse(batch.gtTexts) sparse = self.toSparse(batch.gtTexts)
rate = 0.01 if self.batchesTrained < 10 else (0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate rate = 0.01 if self.batchesTrained < 10 else (0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate
evalList = [self.optimizer, self.loss] evalList = [self.optimizer, self.loss]
feedDict = {self.inputImgs : batch.imgs, self.gtTexts : sparse , self.seqLen : [Model.maxTextLen] * numBatchElements, self.learningRate : rate} feedDict = {self.inputImgs : batch.imgs, self.gtTexts : sparse , self.seqLen : [Model.maxTextLen] * numBatchElements, self.learningRate : rate, self.is_train: True}
(_, lossVal) = self.sess.run(evalList, feedDict) (_, lossVal) = self.sess.run(evalList, feedDict)
self.batchesTrained += 1 self.batchesTrained += 1
return lossVal return lossVal
def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False): def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False):
"feed a batch into the NN to recngnize the texts" "feed a batch into the NN to recognize the texts"
# decode, optionally save RNN output # decode, optionally save RNN output
numBatchElements = len(batch.imgs) numBatchElements = len(batch.imgs)
evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else []) evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else [])
feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements} feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict) evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict)
decoded = evalRes[0] decoded = evalRes[0]
texts = self.decoderOutputToText(decoded, numBatchElements) texts = self.decoderOutputToText(decoded, numBatchElements)
...@@ -228,7 +234,7 @@ class Model: ...@@ -228,7 +234,7 @@ class Model:
sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts) sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts)
ctcInput = evalRes[1] ctcInput = evalRes[1]
evalList = self.lossPerElement evalList = self.lossPerElement
feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements} feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
lossVals = self.sess.run(evalList, feedDict) lossVals = self.sess.run(evalList, feedDict)
probs = np.exp(-lossVals) probs = np.exp(-lossVals)
return (texts, probs) return (texts, probs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment