Skip to content
Snippets Groups Projects
Commit f869e085 authored by Harald Scheidl
Browse files

dynamic batch size

parent cf474341
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,7 @@ class Model:
self.snapID = 0
# input image batch
self.inputImgs = tf.placeholder(tf.float32, shape=(Model.batchSize, Model.imgSize[0], Model.imgSize[1]))
self.inputImgs = tf.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1]))
# setup CNN, RNN and CTC
self.setupCNN()
......@@ -100,7 +100,7 @@ class Model:
self.loss = tf.reduce_mean(tf.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True))
# calc loss for each element to compute label probability
self.savedCtcInput = tf.placeholder(tf.float32, shape=[Model.maxTextLen, Model.batchSize, len(self.charList) + 1])
self.savedCtcInput = tf.placeholder(tf.float32, shape=[Model.maxTextLen, None, len(self.charList) + 1])
self.lossPerElement = tf.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput, sequence_length=self.seqLen, ctc_merge_repeated=True)
# decoder: either best path decoding or beam search decoding
......@@ -168,16 +168,16 @@ class Model:
return (indices, values, shape)
def decoderOutputToText(self, ctcOutput):
def decoderOutputToText(self, ctcOutput, batchSize):
"extract texts from output of CTC decoder"
# contains string of labels for each batch element
encodedLabelStrs = [[] for i in range(Model.batchSize)]
encodedLabelStrs = [[] for i in range(batchSize)]
# word beam search: label strings terminated by blank
if self.decoderType == DecoderType.WordBeamSearch:
blank=len(self.charList)
for b in range(Model.batchSize):
for b in range(batchSize):
for label in ctcOutput[b]:
if label==blank:
break
......@@ -189,7 +189,7 @@ class Model:
decoded=ctcOutput[0][0]
# go over all indices and save mapping: batch -> values
idxDict = { b : [] for b in range(Model.batchSize) }
idxDict = { b : [] for b in range(batchSize) }
for (idx, idx2d) in enumerate(decoded.indices):
label = decoded.values[idx]
batchElement = idx2d[0] # index according to [b,t]
......@@ -201,9 +201,12 @@ class Model:
def trainBatch(self, batch):
    """Feed a batch into the NN to train it.

    The number of batch elements is taken from the batch itself, so batches
    of any size can be fed (the input placeholder has a dynamic first
    dimension).

    Args:
        batch: object providing `imgs` (the input images) and `gtTexts`
            (the ground-truth texts, converted with `self.toSparse`).

    Returns:
        The value of `self.loss` evaluated for this batch.
    """
    numBatchElements = len(batch.imgs)
    sparse = self.toSparse(batch.gtTexts)

    # decay learning rate with the number of batches trained so far
    if self.batchesTrained < 10:
        rate = 0.01
    elif self.batchesTrained < 10000:
        rate = 0.001
    else:
        rate = 0.0001

    evalList = [self.optimizer, self.loss]
    feedDict = {
        self.inputImgs: batch.imgs,
        self.gtTexts: sparse,
        # one sequence length entry per element actually present in the batch
        self.seqLen: [Model.maxTextLen] * numBatchElements,
        self.learningRate: rate,
    }
    # run exactly one optimizer step per call
    (_, lossVal) = self.sess.run(evalList, feedDict)
    self.batchesTrained += 1
    return lossVal
......@@ -212,17 +215,21 @@ class Model:
"feed a batch into the NN to recngnize the texts"
# decode, optionally save RNN output
numBatchElements = len(batch.imgs)
evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else [])
evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], { self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * Model.batchSize } )
feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements}
evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict)
decoded = evalRes[0]
texts = self.decoderOutputToText(decoded)
texts = self.decoderOutputToText(decoded, numBatchElements)
# feed RNN output and recognized text into CTC loss to compute labeling probability
probs = None
if calcProbability:
sparse = self.toSparse(texts)
ctcInput = evalRes[1]
lossVals = self.sess.run(self.lossPerElement, { self.savedCtcInput : ctcInput, self.gtTexts : sparse , self.seqLen : [Model.maxTextLen] * Model.batchSize} )
evalList = self.lossPerElement
feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements}
lossVals = self.sess.run(evalList, feedDict)
probs = np.exp(-lossVals)
return (texts, probs)
......
......@@ -91,10 +91,10 @@ def validate(model, loader):
def infer(model, fnImg):
    """Recognize text in image provided by file path and print the result.

    Args:
        model: object whose `inferBatch(batch, True)` returns a pair
            (recognized texts, probabilities).
        fnImg: path of the image file, read as grayscale.
    """
    img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize)
    # a single-element batch is enough: the batch size is dynamic
    batch = Batch(None, [img])
    (recognized, probability) = model.inferBatch(batch, True)
    print('Recognized:', '"' + recognized[0] + '"')
    print('Probability:', probability[0])
def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment