Skip to content
Snippets Groups Projects
Commit c1ef4818 authored by Harald Scheidl's avatar Harald Scheidl
Browse files

calc probability for recognized text

parent 192d42c0
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,7 @@ The input image and the expected output is shown below.
Validation character error rate of saved model: 13.956289%
Init with stored values from ../model/snapshot-32
Recognized: "little"
Probability: 0.86143184
```
Tested with:
......
......@@ -87,6 +87,7 @@ class DataLoader:
# list of all chars in dataset
self.charList = sorted(list(chars))
def truncateLabel(self, text, maxTextLen):
# ctc_loss can't compute loss if it cannot find a mapping between text label and input
# labels. Repeat letters cost double because of the blank symbol needing to be inserted.
......@@ -101,6 +102,7 @@ class DataLoader:
return text[:i]
return text
def trainSet(self):
"switch to randomly chosen subset of training set"
self.dataAugmentation = True
......
......@@ -2,6 +2,7 @@ from __future__ import division
from __future__ import print_function
import sys
import numpy as np
import tensorflow as tf
......@@ -34,7 +35,7 @@ class Model:
rnnOut3d = self.setupRNN(cnnOut4d)
# CTC
(self.loss, self.decoder) = self.setupCTC(rnnOut3d)
(self.loss, self.lossPerElement, self.decoder) = self.setupCTC(rnnOut3d)
# optimizer for NN parameters
self.batchesTrained = 0
......@@ -92,17 +93,23 @@ class Model:
def setupCTC(self, ctcIn3d):
"create CTC loss and decoder and return them"
# BxTxC -> TxBxC
ctcIn3dTBC = tf.transpose(ctcIn3d, [1, 0, 2])
self.ctcIn3dTBC = tf.transpose(ctcIn3d, [1, 0, 2])
# ground truth text as sparse tensor
self.gtTexts = tf.SparseTensor(tf.placeholder(tf.int64, shape=[None, 2]) , tf.placeholder(tf.int32, [None]), tf.placeholder(tf.int64, [2]))
# calc loss for batch
self.seqLen = tf.placeholder(tf.int32, [None])
loss = tf.nn.ctc_loss(labels=self.gtTexts, inputs=ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True)
loss = tf.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True)
# calc loss for each element to compute label probability
self.savedCtcInput = tf.placeholder(tf.float32, shape=[Model.maxTextLen, Model.batchSize, len(self.charList) + 1])
lossPerElement = tf.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput, sequence_length=self.seqLen, ctc_merge_repeated=True)
# decoder: either best path decoding or beam search decoding
if self.decoderType == DecoderType.BestPath:
decoder = tf.nn.ctc_greedy_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen)
decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen)
elif self.decoderType == DecoderType.BeamSearch:
decoder = tf.nn.ctc_beam_search_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50, merge_repeated=False)
decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50, merge_repeated=False)
elif self.decoderType == DecoderType.WordBeamSearch:
# import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
word_beam_search_module = tf.load_op_library('TFWordBeamSearch.so')
......@@ -116,7 +123,7 @@ class Model:
# the transposed RNN output is stored as self.ctcIn3dTBC in this method; the bare
# local name ctcIn3dTBC no longer exists, so reference the attribute here as well
decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(self.ctcIn3dTBC, dim=2), 50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))
# return a CTC operation to compute the loss and a CTC operation to decode the RNN output
return (tf.reduce_mean(loss), decoder)
return (tf.reduce_mean(loss), lossPerElement, decoder)
def setupTF(self):
......@@ -206,10 +213,23 @@ class Model:
return lossVal
def inferBatch(self, batch, calcProbability=False):
    """Feed a batch into the NN to recognize the texts.

    Args:
        batch: object whose ``imgs`` attribute is fed to ``self.inputImgs``.
        calcProbability: if True, also score each recognized text by feeding
            it back into the per-element CTC loss.

    Returns:
        Tuple ``(texts, probs)``. ``texts`` is the output of
        ``self.decoderOutputToText``. ``probs`` is ``None`` unless
        ``calcProbability`` is True, in which case it is ``exp(-loss)`` of the
        per-element CTC loss computed for the recognized texts.
    """
    # every batch element is given the full sequence length
    seqLen = [Model.maxTextLen] * Model.batchSize

    # decode; fetch the RNN output only when it is needed for the probability,
    # so plain inference does not evaluate/transfer the extra tensor
    evalList = [self.decoder]
    if calcProbability:
        evalList.append(self.ctcIn3dTBC)
    evalRes = self.sess.run(evalList, {self.inputImgs: batch.imgs, self.seqLen: seqLen})
    decoded = evalRes[0]
    texts = self.decoderOutputToText(decoded)

    # feed RNN output and recognized text into CTC to compute labeling probability
    probs = None
    if calcProbability:
        sparse = self.toSparse(texts)
        ctcInput = evalRes[1]
        lossVals = self.sess.run(self.lossPerElement, {self.savedCtcInput: ctcInput, self.gtTexts: sparse, self.seqLen: seqLen})
        # CTC loss is the negative log-probability of the labeling
        probs = np.exp(-lossVals)
    return (texts, probs)
def save(self):
......
......@@ -70,7 +70,7 @@ def validate(model, loader):
iterInfo = loader.getIteratorInfo()
print('Batch:', iterInfo[0],'/', iterInfo[1])
batch = loader.getNext()
recognized = model.inferBatch(batch)
(recognized, _) = model.inferBatch(batch)
print('Ground truth -> Recognized')
for i in range(len(recognized)):
......@@ -92,8 +92,9 @@ def infer(model, fnImg):
"recognize text in image provided by file path"
img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize)
batch = Batch(None, [img] * Model.batchSize) # fill all batch elements with same input image
recognized = model.inferBatch(batch) # recognize text
(recognized, probability) = model.inferBatch(batch, True) # recognize text
print('Recognized:', '"' + recognized[0] + '"') # all batch elements hold same result
print('Probability:', probability[0]) # all batch elements hold same result
def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment