diff --git a/.gitignore b/.gitignore
index 632b7256c5e9c184fb9847e48acc71ecce8c8b23..1b9575d156b3cb7320b4dfc73039b72535140fbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 data/words/
 data/words.txt
+data/words.7z
 src/__pycache__/
 model/checkpoint
 model/snapshot-*
diff --git a/README.md b/README.md
index 16a15d8415bf5c4f3a7d271b989823f7f8ae3a36..39e0e952a6e60f04b0b375d3d9bf84ce58acf898 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
 # Handwritten Text Recognition with TensorFlow
 
+**Update 2020: code is compatible with TF2**
+
 Handwritten Text Recognition (HTR) system implemented with TensorFlow (TF) and trained on the IAM off-line HTR dataset.
 This Neural Network (NN) model recognizes the text contained in the images of segmented words as shown in the illustration below.
 As these word-images are smaller than images of complete text-lines, the NN can be kept small and training on the CPU is feasible.
diff --git a/src/DataLoader.py b/src/DataLoader.py
index 606b85cc9f7fef8984e60c973529cf5d610d2d5a..002d5711525da4d5d5ee3931b6c4d9d305ef6733 100644
--- a/src/DataLoader.py
+++ b/src/DataLoader.py
@@ -3,137 +3,136 @@ from __future__ import print_function
 
 import os
 import random
-import numpy as np
+
 import cv2
+import numpy as np
+
 from SamplePreprocessor import preprocess
 
 
 class Sample:
-	"sample from the dataset"
-	def __init__(self, gtText, filePath):
-		self.gtText = gtText
-		self.filePath = filePath
+    "sample from the dataset"
+    def __init__(self, gtText, filePath):
+        self.gtText = gtText
+        self.filePath = filePath
 
 
-class Batch:
-	"batch containing images and ground truth texts"
-	def __init__(self, gtTexts, imgs):
-		self.imgs = np.stack(imgs, axis=0)
-		self.gtTexts = gtTexts
+class Batch:
+    "batch containing images and ground truth texts"
 
-class DataLoader:
-	"loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
-
-	def __init__(self, filePath, batchSize, imgSize, maxTextLen):
-		"loader for dataset at given location, preprocess images and text according to parameters"
-
-		assert filePath[-1]=='/'
-
-		self.dataAugmentation = False
-		self.currIdx = 0
-		self.batchSize = batchSize
-		self.imgSize = imgSize
-		self.samples = []
-
-		f=open(filePath+'words.txt')
-		chars = set()
-		bad_samples = []
-		bad_samples_reference = ['a01-117-05-02.png', 'r06-022-03-05.png']
-		for line in f:
-			# ignore comment line
-			if not line or line[0]=='#':
-				continue
-
-			lineSplit = line.strip().split(' ')
-			assert len(lineSplit) >= 9
-
-			# filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
-			fileNameSplit = lineSplit[0].split('-')
-			fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + lineSplit[0] + '.png'
-
-			# GT text are columns starting at 9
-			gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen)
-			chars = chars.union(set(list(gtText)))
-
-			# check if image is not empty
-			if not os.path.getsize(fileName):
-				bad_samples.append(lineSplit[0] + '.png')
-				continue
-
-			# put sample into list
-			self.samples.append(Sample(gtText, fileName))
-
-		# some images in the IAM dataset are known to be damaged, don't show warning for them
-		if set(bad_samples) != set(bad_samples_reference):
-			print("Warning, damaged images found:", bad_samples)
-			print("Damaged images expected:", bad_samples_reference)
-
-		# split into training and validation set: 95% - 5%
-		splitIdx = int(0.95 * len(self.samples))
-		self.trainSamples = self.samples[:splitIdx]
-		self.validationSamples = self.samples[splitIdx:]
-
-		# put words into lists
-		self.trainWords = [x.gtText for x in self.trainSamples]
-		self.validationWords = [x.gtText for x in self.validationSamples]
-
-		# number of randomly chosen samples per epoch for training
-		self.numTrainSamplesPerEpoch = 25000
-
-		# start with train set
-		self.trainSet()
-
-		# list of all chars in dataset
-		self.charList = sorted(list(chars))
-
-
-	def truncateLabel(self, text, maxTextLen):
-		# ctc_loss can't compute loss if it cannot find a mapping between text label and input
-		# labels. Repeat letters cost double because of the blank symbol needing to be inserted.
-		# If a too-long label is provided, ctc_loss returns an infinite gradient
-		cost = 0
-		for i in range(len(text)):
-			if i != 0 and text[i] == text[i-1]:
-				cost += 2
-			else:
-				cost += 1
-			if cost > maxTextLen:
-				return text[:i]
-		return text
-
-
-	def trainSet(self):
-		"switch to randomly chosen subset of training set"
-		self.dataAugmentation = True
-		self.currIdx = 0
-		random.shuffle(self.trainSamples)
-		self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]
-
-
-	def validationSet(self):
-		"switch to validation set"
-		self.dataAugmentation = False
-		self.currIdx = 0
-		self.samples = self.validationSamples
-
-
-	def getIteratorInfo(self):
-		"current batch index and overall number of batches"
-		return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)
-
-
-	def hasNext(self):
-		"iterator"
-		return self.currIdx + self.batchSize <= len(self.samples)
-
-
-	def getNext(self):
-		"iterator"
-		batchRange = range(self.currIdx, self.currIdx + self.batchSize)
-		gtTexts = [self.samples[i].gtText for i in batchRange]
-		imgs = [preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize, self.dataAugmentation) for i in batchRange]
-		self.currIdx += self.batchSize
-		return Batch(gtTexts, imgs)
+    def __init__(self, gtTexts, imgs):
+        self.imgs = np.stack(imgs, axis=0)
+        self.gtTexts = gtTexts
+
+
+class DataLoader:
+    "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
+
+    def __init__(self, filePath, batchSize, imgSize, maxTextLen):
+        "loader for dataset at given location, preprocess images and text according to parameters"
+
+        assert filePath[-1] == '/'
+
+        self.dataAugmentation = False
+        self.currIdx = 0
+        self.batchSize = batchSize
+        self.imgSize = imgSize
+        self.samples = []
+
+        f = open(filePath + 'words.txt')
+        chars = set()
+        bad_samples = []
+        bad_samples_reference = ['a01-117-05-02.png', 'r06-022-03-05.png']
+        for line in f:
+            # ignore comment line
+            if not line or line[0] == '#':
+                continue
+
+            lineSplit = line.strip().split(' ')
+            assert len(lineSplit) >= 9
+
+            # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
+            fileNameSplit = lineSplit[0].split('-')
+            fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + \
+                       lineSplit[0] + '.png'
+
+            # GT text is in the columns starting at 9
+            gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen)
+            chars = chars.union(set(list(gtText)))
+
+            # check if image is not empty
+            if not os.path.getsize(fileName):
+                bad_samples.append(lineSplit[0] + '.png')
+                continue
+
+            # put sample into list
+            self.samples.append(Sample(gtText, fileName))
+
+        # some images in the IAM dataset are known to be damaged, don't show warning for them
+        if set(bad_samples) != set(bad_samples_reference):
+            print("Warning, damaged images found:", bad_samples)
+            print("Damaged images expected:", bad_samples_reference)
+
+        # split into training and validation set: 95% - 5%
+        splitIdx = int(0.95 * len(self.samples))
+        self.trainSamples = self.samples[:splitIdx]
+        self.validationSamples = self.samples[splitIdx:]
+
+        # put words into lists
+        self.trainWords = [x.gtText for x in self.trainSamples]
+        self.validationWords = [x.gtText for x in self.validationSamples]
+
+        # number of randomly chosen samples per epoch for training
+        self.numTrainSamplesPerEpoch = 25000
+
+        # start with train set
+        self.trainSet()
+
+        # list of all chars in dataset
+        self.charList = sorted(list(chars))
+
+    def truncateLabel(self, text, maxTextLen):
+        # ctc_loss can't compute loss if it cannot find a mapping between text label and input
+        # labels. Repeat letters cost double because of the blank symbol needing to be inserted.
+        # If a too-long label is provided, ctc_loss returns an infinite gradient
+        cost = 0
+        for i in range(len(text)):
+            if i != 0 and text[i] == text[i - 1]:
+                cost += 2
+            else:
+                cost += 1
+            if cost > maxTextLen:
+                return text[:i]
+        return text
+
+    def trainSet(self):
+        "switch to randomly chosen subset of training set"
+        self.dataAugmentation = True
+        self.currIdx = 0
+        random.shuffle(self.trainSamples)
+        self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]
+
+    def validationSet(self):
+        "switch to validation set"
+        self.dataAugmentation = False
+        self.currIdx = 0
+        self.samples = self.validationSamples
+
+    def getIteratorInfo(self):
+        "current batch index and overall number of batches"
+        return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)
+
+    def hasNext(self):
+        "iterator"
+        return self.currIdx + self.batchSize <= len(self.samples)
+
+    def getNext(self):
+        "iterator"
+        batchRange = range(self.currIdx, self.currIdx + self.batchSize)
+        gtTexts = [self.samples[i].gtText for i in batchRange]
+        imgs = [
+            preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize, self.dataAugmentation)
+            for i in batchRange]
+        self.currIdx += self.batchSize
+        return Batch(gtTexts, imgs)
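Reviewer note: the loader's public surface (trainSet/validationSet/hasNext/getNext) is untouched by this reformatting, so a minimal smoke test along the following lines should behave exactly as before. The '../data/' layout (words.txt next to the words/ folder) and the (128, 32)/32 constants are assumptions taken from FilePaths in main.py and Model's constants, not something this patch guarantees:

    # hedged smoke test: run from src/ with the IAM files unpacked under ../data/
    from DataLoader import DataLoader

    loader = DataLoader('../data/', batchSize=50, imgSize=(128, 32), maxTextLen=32)
    loader.validationSet()  # deterministic order, no augmentation
    if loader.hasNext():
        batch = loader.getNext()  # Batch with .imgs stacked to (50, 128, 32) and .gtTexts
        print(loader.getIteratorInfo(), batch.gtTexts[:3])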
diff --git a/src/Model.py b/src/Model.py
index e987acf2b0ae2a436e4b6dda8ec69f6b66b6910f..07e300fd880eb93ba401eb05b792d764ae8bd2fc 100644
--- a/src/Model.py
+++ b/src/Model.py
@@ -1,276 +1,285 @@
 from __future__ import division
 from __future__ import print_function
 
-import sys
 import numpy as np
-import tensorflow as tf
 import os
+import sys
+import tensorflow as tf
 
-# Disable eagre
+# Disable eager mode
 tf.compat.v1.disable_eager_execution()
 
+
 class DecoderType:
-	BestPath = 0
-	BeamSearch = 1
-	WordBeamSearch = 2
+    BestPath = 0
+    BeamSearch = 1
+    WordBeamSearch = 2
 
 
 class Model:
-	"minimalistic TF model for HTR"
-
-	# model constants
-	batchSize = 50
-	imgSize = (128, 32)
-	maxTextLen = 32
-
-	def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False, dump=False):
-		"init model: add CNN, RNN and CTC and initialize TF"
-		self.dump = dump
-		self.charList = charList
-		self.decoderType = decoderType
-		self.mustRestore = mustRestore
-		self.snapID = 0
-
-		# Whether to use normalization over a batch or a population
-		self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')
-
-		# input image batch
-		self.inputImgs = tf.compat.v1.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1]))
-
-		# setup CNN, RNN and CTC
-		self.setupCNN()
-		self.setupRNN()
-		self.setupCTC()
-
-		# setup optimizer to train NN
-		self.batchesTrained = 0
-		self.learningRate = tf.compat.v1.placeholder(tf.float32, shape=[])
-		self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
-		with tf.control_dependencies(self.update_ops):
-			self.optimizer = tf.compat.v1.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)
-
-		# initialize TF
-		(self.sess, self.saver) = self.setupTF()
-
-
-	def setupCNN(self):
-		"create CNN layers and return output of these layers"
-		cnnIn4d = tf.expand_dims(input=self.inputImgs, axis=3)
-
-		# list of parameters for the layers
-		kernelVals = [5, 5, 3, 3, 3]
-		featureVals = [1, 32, 64, 128, 128, 256]
-		strideVals = poolVals = [(2,2), (2,2), (1,2), (1,2), (1,2)]
-		numLayers = len(strideVals)
-
-		# create layers
-		pool = cnnIn4d # input to first CNN layer
-		for i in range(numLayers):
-			kernel = tf.Variable(tf.random.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1))
-			conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1,1,1,1))
-			conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train)
-			relu = tf.nn.relu(conv_norm)
-			pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1), strides=(1, strideVals[i][0], strideVals[i][1], 1), padding='VALID')
-
-		self.cnnOut4d = pool
-
-
-	def setupRNN(self):
-		"create RNN layers and return output of these layers"
-		rnnIn3d = tf.squeeze(self.cnnOut4d, axis=[2])
-
-		# basic cells which is used to build RNN
-		numHidden = 256
-		cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in range(2)] # 2 layers
-
-		# stack basic cells
-		stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
-
-		# bidirectional RNN
-		# BxTxF -> BxTx2H
-		((fw, bw), _) = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, dtype=rnnIn3d.dtype)
-
-		# BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
-		concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)
-
-		# project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
-		kernel = tf.Variable(tf.random.truncated_normal([1, 1, numHidden * 2, len(self.charList) + 1], stddev=0.1))
-		self.rnnOut3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])
-
-
-	def setupCTC(self):
-		"create CTC loss and decoder and return them"
-		# BxTxC -> TxBxC
-		self.ctcIn3dTBC = tf.transpose(a=self.rnnOut3d, perm=[1, 0, 2])
-		# ground truth text as sparse tensor
-		self.gtTexts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]) , tf.compat.v1.placeholder(tf.int32, [None]), tf.compat.v1.placeholder(tf.int64, [2]))
-
-		# calc loss for batch
-		self.seqLen = tf.compat.v1.placeholder(tf.int32, [None])
-		self.loss = tf.reduce_mean(input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True))
-
-		# calc loss for each element to compute label probability
-		self.savedCtcInput = tf.compat.v1.placeholder(tf.float32, shape=[Model.maxTextLen, None, len(self.charList) + 1])
-		self.lossPerElement = tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput, sequence_length=self.seqLen, ctc_merge_repeated=True)
-
-		# decoder: either best path decoding or beam search decoding
-		if self.decoderType == DecoderType.BestPath:
-			self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen)
-		elif self.decoderType == DecoderType.BeamSearch:
-			self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50)
-		elif self.decoderType == DecoderType.WordBeamSearch:
-			# import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
-			word_beam_search_module = tf.load_op_library('TFWordBeamSearch.so')
-
-			# prepare information about language (dictionary, characters in dataset, characters forming words)
-			chars = str().join(self.charList)
-			wordChars = open('../model/wordCharList.txt').read().splitlines()[0]
-			corpus = open('../data/corpus.txt').read()
-
-			# decode using the "Words" mode of word beam search
-			self.decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(self.ctcIn3dTBC, axis=2), 50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))
-
-
-	def setupTF(self):
-		"initialize TF"
-		print('Python: '+sys.version)
-		print('Tensorflow: '+tf.__version__)
-
-		sess=tf.compat.v1.Session() # TF session
-
-		saver = tf.compat.v1.train.Saver(max_to_keep=1) # saver saves model to file
-		modelDir = '../model/'
-		latestSnapshot = tf.train.latest_checkpoint(modelDir) # is there a saved model?
-
-		# if model must be restored (for inference), there must be a snapshot
-		if self.mustRestore and not latestSnapshot:
-			raise Exception('No saved model found in: ' + modelDir)
-
-		# load saved model if available
-		if latestSnapshot:
-			print('Init with stored values from ' + latestSnapshot)
-			saver.restore(sess, latestSnapshot)
-		else:
-			print('Init with new values')
-			sess.run(tf.compat.v1.global_variables_initializer())
-
-		return (sess,saver)
-
-
-	def toSparse(self, texts):
-		"put ground truth texts into sparse tensor for ctc_loss"
-		indices = []
-		values = []
-		shape = [len(texts), 0] # last entry must be max(labelList[i])
-
-		# go over all texts
-		for (batchElement, text) in enumerate(texts):
-			# convert to string of label (i.e. class-ids)
-			labelStr = [self.charList.index(c) for c in text]
-			# sparse tensor must have size of max. label-string
-			if len(labelStr) > shape[1]:
-				shape[1] = len(labelStr)
-			# put each label into sparse tensor
-			for (i, label) in enumerate(labelStr):
-				indices.append([batchElement, i])
-				values.append(label)
-
-		return (indices, values, shape)
-
-
-	def decoderOutputToText(self, ctcOutput, batchSize):
-		"extract texts from output of CTC decoder"
-
-		# contains string of labels for each batch element
-		encodedLabelStrs = [[] for i in range(batchSize)]
-
-		# word beam search: label strings terminated by blank
-		if self.decoderType == DecoderType.WordBeamSearch:
-			blank=len(self.charList)
-			for b in range(batchSize):
-				for label in ctcOutput[b]:
-					if label==blank:
-						break
-					encodedLabelStrs[b].append(label)
-
-		# TF decoders: label strings are contained in sparse tensor
-		else:
-			# ctc returns tuple, first element is SparseTensor
-			decoded=ctcOutput[0][0]
-
-			# go over all indices and save mapping: batch -> values
-			idxDict = { b : [] for b in range(batchSize) }
-			for (idx, idx2d) in enumerate(decoded.indices):
-				label = decoded.values[idx]
-				batchElement = idx2d[0] # index according to [b,t]
-				encodedLabelStrs[batchElement].append(label)
-
-		# map labels to chars for all batch elements
-		return [str().join([self.charList[c] for c in labelStr]) for labelStr in encodedLabelStrs]
-
-
-	def trainBatch(self, batch):
-		"feed a batch into the NN to train it"
-		numBatchElements = len(batch.imgs)
-		sparse = self.toSparse(batch.gtTexts)
-		rate = 0.01 if self.batchesTrained < 10 else (0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate
-		evalList = [self.optimizer, self.loss]
-		feedDict = {self.inputImgs : batch.imgs, self.gtTexts : sparse , self.seqLen : [Model.maxTextLen] * numBatchElements, self.learningRate : rate, self.is_train: True}
-		(_, lossVal) = self.sess.run(evalList, feedDict)
-		self.batchesTrained += 1
-		return lossVal
-
-
-	def dumpNNOutput(self, rnnOutput):
-		"dump the output of the NN to CSV file(s)"
-		dumpDir = '../dump/'
-		if not os.path.isdir(dumpDir):
-			os.mkdir(dumpDir)
-
-		# iterate over all batch elements and create a CSV file for each one
-		maxT, maxB, maxC = rnnOutput.shape
-		for b in range(maxB):
-			csv = ''
-			for t in range(maxT):
-				for c in range(maxC):
-					csv += str(rnnOutput[t, b, c]) + ';'
-				csv += '\n'
-			fn = dumpDir + 'rnnOutput_'+str(b)+'.csv'
-			print('Write dump of NN to file: ' + fn)
-			with open(fn, 'w') as f:
-				f.write(csv)
-
-
-	def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False):
-		"feed a batch into the NN to recognize the texts"
-
-		# decode, optionally save RNN output
-		numBatchElements = len(batch.imgs)
-		evalRnnOutput = self.dump or calcProbability
-		evalList = [self.decoder] + ([self.ctcIn3dTBC] if evalRnnOutput else [])
-		feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
-		evalRes = self.sess.run(evalList, feedDict)
-		decoded = evalRes[0]
-		texts = self.decoderOutputToText(decoded, numBatchElements)
-
-		# feed RNN output and recognized text into CTC loss to compute labeling probability
-		probs = None
-		if calcProbability:
-			sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts)
-			ctcInput = evalRes[1]
-			evalList = self.lossPerElement
-			feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
-			lossVals = self.sess.run(evalList, feedDict)
-			probs = np.exp(-lossVals)
-
-		# dump the output of the NN to CSV file(s)
-		if self.dump:
-			self.dumpNNOutput(evalRes[1])
-
-		return (texts, probs)
-
-
-	def save(self):
-		"save model to file"
-		self.snapID += 1
-		self.saver.save(self.sess, '../model/snapshot', global_step=self.snapID)
+    "minimalistic TF model for HTR"
+
+    # model constants
+    batchSize = 50
+    imgSize = (128, 32)
+    maxTextLen = 32
+
+    def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False, dump=False):
+        "init model: add CNN, RNN and CTC and initialize TF"
+        self.dump = dump
+        self.charList = charList
+        self.decoderType = decoderType
+        self.mustRestore = mustRestore
+        self.snapID = 0
+
+        # Whether to use normalization over a batch or a population
+        self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')
+
+        # input image batch
+        self.inputImgs = tf.compat.v1.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1]))
+
+        # setup CNN, RNN and CTC
+        self.setupCNN()
+        self.setupRNN()
+        self.setupCTC()
+
+        # setup optimizer to train NN
+        self.batchesTrained = 0
+        self.learningRate = tf.compat.v1.placeholder(tf.float32, shape=[])
+        self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
+        with tf.control_dependencies(self.update_ops):
+            self.optimizer = tf.compat.v1.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)
+
+        # initialize TF
+        (self.sess, self.saver) = self.setupTF()
+
+    def setupCNN(self):
+        "create CNN layers and return output of these layers"
+        cnnIn4d = tf.expand_dims(input=self.inputImgs, axis=3)
+
+        # list of parameters for the layers
+        kernelVals = [5, 5, 3, 3, 3]
+        featureVals = [1, 32, 64, 128, 128, 256]
+        strideVals = poolVals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
+        numLayers = len(strideVals)
+
+        # create layers
+        pool = cnnIn4d  # input to first CNN layer
+        for i in range(numLayers):
+            kernel = tf.Variable(
+                tf.random.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]],
+                                           stddev=0.1))
+            conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
+            conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train)
+            relu = tf.nn.relu(conv_norm)
+            pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1),
+                                    strides=(1, strideVals[i][0], strideVals[i][1], 1), padding='VALID')
+
+        self.cnnOut4d = pool
+
+    def setupRNN(self):
+        "create RNN layers and return output of these layers"
+        rnnIn3d = tf.squeeze(self.cnnOut4d, axis=[2])
+
+        # basic cells which are used to build RNN
+        numHidden = 256
+        cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in
+                 range(2)]  # 2 layers
+
+        # stack basic cells
+        stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
+
+        # bidirectional RNN
+        # BxTxF -> BxTx2H
+        ((fw, bw), _) = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d,
+                                                                  dtype=rnnIn3d.dtype)
+
+        # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
+        concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)
+
+        # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
+        kernel = tf.Variable(tf.random.truncated_normal([1, 1, numHidden * 2, len(self.charList) + 1], stddev=0.1))
+        self.rnnOut3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])
+
+    def setupCTC(self):
+        "create CTC loss and decoder and return them"
+        # BxTxC -> TxBxC
+        self.ctcIn3dTBC = tf.transpose(a=self.rnnOut3d, perm=[1, 0, 2])
+        # ground truth text as sparse tensor
+        self.gtTexts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]),
+                                       tf.compat.v1.placeholder(tf.int32, [None]),
+                                       tf.compat.v1.placeholder(tf.int64, [2]))
+
+        # calc loss for batch
+        self.seqLen = tf.compat.v1.placeholder(tf.int32, [None])
+        self.loss = tf.reduce_mean(input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC,
+                                                                         sequence_length=self.seqLen,
+                                                                         ctc_merge_repeated=True))
+
+        # calc loss for each element to compute label probability
+        self.savedCtcInput = tf.compat.v1.placeholder(tf.float32,
+                                                      shape=[Model.maxTextLen, None, len(self.charList) + 1])
+        self.lossPerElement = tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput,
+                                                       sequence_length=self.seqLen, ctc_merge_repeated=True)
+
+        # decoder: either best path decoding or beam search decoding
+        if self.decoderType == DecoderType.BestPath:
+            self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen)
+        elif self.decoderType == DecoderType.BeamSearch:
+            self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen,
+                                                         beam_width=50)
+        elif self.decoderType == DecoderType.WordBeamSearch:
+            # import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
+            word_beam_search_module = tf.load_op_library('TFWordBeamSearch.so')
+
+            # prepare information about language (dictionary, characters in dataset, characters forming words)
+            chars = str().join(self.charList)
+            wordChars = open('../model/wordCharList.txt').read().splitlines()[0]
+            corpus = open('../data/corpus.txt').read()
+
+            # decode using the "Words" mode of word beam search
+            self.decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(self.ctcIn3dTBC, axis=2), 50, 'Words',
+                                                                    0.0, corpus.encode('utf8'), chars.encode('utf8'),
+                                                                    wordChars.encode('utf8'))
+
+    def setupTF(self):
+        "initialize TF"
+        print('Python: ' + sys.version)
+        print('Tensorflow: ' + tf.__version__)
+
+        sess = tf.compat.v1.Session()  # TF session
+
+        saver = tf.compat.v1.train.Saver(max_to_keep=1)  # saver saves model to file
+        modelDir = '../model/'
+        latestSnapshot = tf.train.latest_checkpoint(modelDir)  # is there a saved model?
+
+        # if model must be restored (for inference), there must be a snapshot
+        if self.mustRestore and not latestSnapshot:
+            raise Exception('No saved model found in: ' + modelDir)
+
+        # load saved model if available
+        if latestSnapshot:
+            print('Init with stored values from ' + latestSnapshot)
+            saver.restore(sess, latestSnapshot)
+        else:
+            print('Init with new values')
+            sess.run(tf.compat.v1.global_variables_initializer())
+
+        return (sess, saver)
+
+    def toSparse(self, texts):
+        "put ground truth texts into sparse tensor for ctc_loss"
+        indices = []
+        values = []
+        shape = [len(texts), 0]  # last entry must be max(labelList[i])
+
+        # go over all texts
+        for (batchElement, text) in enumerate(texts):
+            # convert to string of label (i.e. class-ids)
+            labelStr = [self.charList.index(c) for c in text]
+            # sparse tensor must have size of max. label-string
+            if len(labelStr) > shape[1]:
+                shape[1] = len(labelStr)
+            # put each label into sparse tensor
+            for (i, label) in enumerate(labelStr):
+                indices.append([batchElement, i])
+                values.append(label)
+
+        return (indices, values, shape)
+
+    def decoderOutputToText(self, ctcOutput, batchSize):
+        "extract texts from output of CTC decoder"
+
+        # contains string of labels for each batch element
+        encodedLabelStrs = [[] for i in range(batchSize)]
+
+        # word beam search: label strings terminated by blank
+        if self.decoderType == DecoderType.WordBeamSearch:
+            blank = len(self.charList)
+            for b in range(batchSize):
+                for label in ctcOutput[b]:
+                    if label == blank:
+                        break
+                    encodedLabelStrs[b].append(label)
+
+        # TF decoders: label strings are contained in sparse tensor
+        else:
+            # ctc returns tuple, first element is SparseTensor
+            decoded = ctcOutput[0][0]
+
+            # go over all indices and save mapping: batch -> values
+            idxDict = {b: [] for b in range(batchSize)}
+            for (idx, idx2d) in enumerate(decoded.indices):
+                label = decoded.values[idx]
+                batchElement = idx2d[0]  # index according to [b,t]
+                encodedLabelStrs[batchElement].append(label)
+
+        # map labels to chars for all batch elements
+        return [str().join([self.charList[c] for c in labelStr]) for labelStr in encodedLabelStrs]
+
+    def trainBatch(self, batch):
+        "feed a batch into the NN to train it"
+        numBatchElements = len(batch.imgs)
+        sparse = self.toSparse(batch.gtTexts)
+        rate = 0.01 if self.batchesTrained < 10 else (
+            0.001 if self.batchesTrained < 10000 else 0.0001)  # decay learning rate
+        evalList = [self.optimizer, self.loss]
+        feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse,
+                    self.seqLen: [Model.maxTextLen] * numBatchElements, self.learningRate: rate, self.is_train: True}
+        (_, lossVal) = self.sess.run(evalList, feedDict)
+        self.batchesTrained += 1
+        return lossVal
+
+    def dumpNNOutput(self, rnnOutput):
+        "dump the output of the NN to CSV file(s)"
+        dumpDir = '../dump/'
+        if not os.path.isdir(dumpDir):
+            os.mkdir(dumpDir)
+
+        # iterate over all batch elements and create a CSV file for each one
+        maxT, maxB, maxC = rnnOutput.shape
+        for b in range(maxB):
+            csv = ''
+            for t in range(maxT):
+                for c in range(maxC):
+                    csv += str(rnnOutput[t, b, c]) + ';'
+                csv += '\n'
+            fn = dumpDir + 'rnnOutput_' + str(b) + '.csv'
+            print('Write dump of NN to file: ' + fn)
+            with open(fn, 'w') as f:
+                f.write(csv)
+
+    def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False):
+        "feed a batch into the NN to recognize the texts"
+
+        # decode, optionally save RNN output
+        numBatchElements = len(batch.imgs)
+        evalRnnOutput = self.dump or calcProbability
+        evalList = [self.decoder] + ([self.ctcIn3dTBC] if evalRnnOutput else [])
+        feedDict = {self.inputImgs: batch.imgs, self.seqLen: [Model.maxTextLen] * numBatchElements,
+                    self.is_train: False}
+        evalRes = self.sess.run(evalList, feedDict)
+        decoded = evalRes[0]
+        texts = self.decoderOutputToText(decoded, numBatchElements)
+
+        # feed RNN output and recognized text into CTC loss to compute labeling probability
+        probs = None
+        if calcProbability:
+            sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts)
+            ctcInput = evalRes[1]
+            evalList = self.lossPerElement
+            feedDict = {self.savedCtcInput: ctcInput, self.gtTexts: sparse,
+                        self.seqLen: [Model.maxTextLen] * numBatchElements, self.is_train: False}
+            lossVals = self.sess.run(evalList, feedDict)
+            probs = np.exp(-lossVals)
+
+        # dump the output of the NN to CSV file(s)
+        if self.dump:
+            self.dumpNNOutput(evalRes[1])
+
+        return (texts, probs)
+
+    def save(self):
+        "save model to file"
+        self.snapID += 1
+        self.saver.save(self.sess, '../model/snapshot', global_step=self.snapID)
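Reviewer note: the `tf.squeeze(self.cnnOut4d, axis=[2])` in setupRNN only works because the five pooling steps shrink the 128x32 input to 32 time steps of height 1. A quick back-of-the-envelope check, plain Python with the stride/pool values from setupCNN (no TF needed):

    # width shrinks by the first component of each stride/pool pair, height by the second
    w, h = 128, 32
    for sw, sh in [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]:
        w, h = w // sw, h // sh
    print(w, h)  # -> 32 1: 32 time steps of height 1, so axis 2 can be squeezed

The 32 time steps are also why `seqLen` is fed as `[Model.maxTextLen] * numBatchElements` with maxTextLen = 32.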
diff --git a/src/SamplePreprocessor.py b/src/SamplePreprocessor.py
index 9796d2e98162e2d0fcfd5b85eee28e44b504da8d..1d571e9f33a858c58021ef681996b19c2c306424 100644
--- a/src/SamplePreprocessor.py
+++ b/src/SamplePreprocessor.py
@@ -2,42 +2,43 @@ from __future__ import division
 from __future__ import print_function
 
 import random
-import numpy as np
+
 import cv2
+import numpy as np
 
 
 def preprocess(img, imgSize, dataAugmentation=False):
-	"put img into target img of size imgSize, transpose for TF and normalize gray-values"
-
-	# there are damaged files in IAM dataset - just use black image instead
-	if img is None:
-		img = np.zeros([imgSize[1], imgSize[0]])
-
-	# increase dataset size by applying random stretches to the images
-	if dataAugmentation:
-		stretch = (random.random() - 0.5) # -0.5 .. +0.5
-		wStretched = max(int(img.shape[1] * (1 + stretch)), 1) # random width, but at least 1
-		img = cv2.resize(img, (wStretched, img.shape[0])) # stretch horizontally by factor 0.5 .. 1.5
-
-	# create target image and copy sample image into it
-	(wt, ht) = imgSize
-	(h, w) = img.shape
-	fx = w / wt
-	fy = h / ht
-	f = max(fx, fy)
-	newSize = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1)) # scale according to f (result at least 1 and at most wt or ht)
-	img = cv2.resize(img, newSize)
-	target = np.ones([ht, wt]) * 255
-	target[0:newSize[1], 0:newSize[0]] = img
-
-	# transpose for TF
-	img = cv2.transpose(target)
-
-	# normalize
-	(m, s) = cv2.meanStdDev(img)
-	m = m[0][0]
-	s = s[0][0]
-	img = img - m
-	img = img / s if s>0 else img
-	return img
-
+    "put img into target img of size imgSize, transpose for TF and normalize gray-values"
+
+    # there are damaged files in IAM dataset - just use black image instead
+    if img is None:
+        img = np.zeros([imgSize[1], imgSize[0]])
+
+    # increase dataset size by applying random stretches to the images
+    if dataAugmentation:
+        stretch = (random.random() - 0.5)  # -0.5 .. +0.5
+        wStretched = max(int(img.shape[1] * (1 + stretch)), 1)  # random width, but at least 1
+        img = cv2.resize(img, (wStretched, img.shape[0]))  # stretch horizontally by factor 0.5 .. 1.5
+
+    # create target image and copy sample image into it
+    (wt, ht) = imgSize
+    (h, w) = img.shape
+    fx = w / wt
+    fy = h / ht
+    f = max(fx, fy)
+    newSize = (max(min(wt, int(w / f)), 1),
+               max(min(ht, int(h / f)), 1))  # scale according to f (result at least 1 and at most wt or ht)
+    img = cv2.resize(img, newSize)
+    target = np.ones([ht, wt]) * 255
+    target[0:newSize[1], 0:newSize[0]] = img
+
+    # transpose for TF
+    img = cv2.transpose(target)
+
+    # normalize
+    (m, s) = cv2.meanStdDev(img)
+    m = m[0][0]
+    s = s[0][0]
+    img = img - m
+    img = img / s if s > 0 else img
+    return img
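Reviewer note: the scaling logic in preprocess() keeps the aspect ratio and pads the remainder of the target with white. A worked example with a hypothetical 350x90 word image makes the arithmetic concrete:

    wt, ht = 128, 32               # target size used by Model
    h, w = 90, 350                 # example input, h x w as returned by img.shape
    f = max(w / wt, h / ht)        # max(2.734..., 2.8125) = 2.8125
    newSize = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1))
    print(newSize)                 # -> (124, 32); columns 124..127 stay white (255)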
diff --git a/src/analyze.py b/src/analyze.py
index 0796c0215216543e6c3c992b01f8a503147dfdf8..d34a97b93cdf1944289fb94453f16ea27ef82e45 100644
--- a/src/analyze.py
+++ b/src/analyze.py
@@ -1,13 +1,15 @@
 from __future__ import division
 from __future__ import print_function
 
-import sys
+import copy
 import math
 import pickle
-import copy
-import numpy as np
+import sys
+
 import cv2
 import matplotlib.pyplot as plt
+import numpy as np
+
 from DataLoader import Batch
 from Model import Model, DecoderType
 from SamplePreprocessor import preprocess
@@ -15,145 +17,143 @@ from SamplePreprocessor import preprocess
 
 # constants like filepaths
 class Constants:
-	"filenames and paths to data"
-	fnCharList = '../model/charList.txt'
-	fnAnalyze = '../data/analyze.png'
-	fnPixelRelevance = '../data/pixelRelevance.npy'
-	fnTranslationInvariance = '../data/translationInvariance.npy'
-	fnTranslationInvarianceTexts = '../data/translationInvarianceTexts.pickle'
-	gtText = 'are'
-	distribution = 'histogram' # 'histogram' or 'uniform'
+    "filenames and paths to data"
+    fnCharList = '../model/charList.txt'
+    fnAnalyze = '../data/analyze.png'
+    fnPixelRelevance = '../data/pixelRelevance.npy'
+    fnTranslationInvariance = '../data/translationInvariance.npy'
+    fnTranslationInvarianceTexts = '../data/translationInvarianceTexts.pickle'
+    gtText = 'are'
+    distribution = 'histogram'  # 'histogram' or 'uniform'
 
 
 def odds(val):
-	return val / (1 - val)
+    return val / (1 - val)
 
 
 def weightOfEvidence(origProb, margProb):
-	return math.log2(odds(origProb)) - math.log2(odds(margProb))
-
-
+    return math.log2(odds(origProb)) - math.log2(odds(margProb))
+
+
 def analyzePixelRelevance():
-	"simplified implementation of paper: Zintgraf et al - Visualizing Deep Neural Network Decisions: Prediction Difference Analysis"
-
-	# setup model
-	model = Model(open(Constants.fnCharList).read(), DecoderType.BestPath, mustRestore=True)
-
-	# read image and specify ground-truth text
-	img = cv2.imread(Constants.fnAnalyze, cv2.IMREAD_GRAYSCALE)
-	(w, h) = img.shape
-	assert Model.imgSize[1] == w
-
-	# compute probability of gt text in original image
-	batch = Batch([Constants.gtText], [preprocess(img, Model.imgSize)])
-	(_, probs) = model.inferBatch(batch, calcProbability=True, probabilityOfGT=True)
-	origProb = probs[0]
-
-	grayValues = [0, 63, 127, 191, 255]
-	if Constants.distribution == 'histogram':
-		bins = [0, 31, 95, 159, 223, 255]
-		(hist, _) = np.histogram(img, bins=bins)
-		pixelProb = hist / sum(hist)
-	elif Constants.distribution == 'uniform':
-		pixelProb = [1.0 / len(grayValues) for _ in grayValues]
-	else:
-		raise Exception('unknown value for Constants.distribution')
-
-	# iterate over all pixels in image
-	pixelRelevance = np.zeros(img.shape, np.float32)
-	for x in range(w):
-		for y in range(h):
-
-			# try a subset of possible grayvalues of pixel (x,y)
-			imgsMarginalized = []
-			for g in grayValues:
-				imgChanged = copy.deepcopy(img)
-				imgChanged[x, y] = g
-				imgsMarginalized.append(preprocess(imgChanged, Model.imgSize))
-
-			# put them all into one batch
-			batch = Batch([Constants.gtText]*len(imgsMarginalized), imgsMarginalized)
-
-			# compute probabilities
-			(_, probs) = model.inferBatch(batch, calcProbability=True, probabilityOfGT=True)
-
-			# marginalize over pixel value (assume uniform distribution)
-			margProb = sum([probs[i] * pixelProb[i] for i in range(len(grayValues))])
-
-			pixelRelevance[x, y] = weightOfEvidence(origProb, margProb)
-
-			print(x, y, pixelRelevance[x, y], origProb, margProb)
-
-	np.save(Constants.fnPixelRelevance, pixelRelevance)
+    "simplified implementation of paper: Zintgraf et al - Visualizing Deep Neural Network Decisions: Prediction Difference Analysis"
+
+    # setup model
+    model = Model(open(Constants.fnCharList).read(), DecoderType.BestPath, mustRestore=True)
+
+    # read image and specify ground-truth text
+    img = cv2.imread(Constants.fnAnalyze, cv2.IMREAD_GRAYSCALE)
+    (w, h) = img.shape
+    assert Model.imgSize[1] == w
+
+    # compute probability of gt text in original image
+    batch = Batch([Constants.gtText], [preprocess(img, Model.imgSize)])
+    (_, probs) = model.inferBatch(batch, calcProbability=True, probabilityOfGT=True)
+    origProb = probs[0]
+
+    grayValues = [0, 63, 127, 191, 255]
+    if Constants.distribution == 'histogram':
+        bins = [0, 31, 95, 159, 223, 255]
+        (hist, _) = np.histogram(img, bins=bins)
+        pixelProb = hist / sum(hist)
+    elif Constants.distribution == 'uniform':
+        pixelProb = [1.0 / len(grayValues) for _ in grayValues]
+    else:
+        raise Exception('unknown value for Constants.distribution')
+
+    # iterate over all pixels in image
+    pixelRelevance = np.zeros(img.shape, np.float32)
+    for x in range(w):
+        for y in range(h):
+
+            # try a subset of possible grayvalues of pixel (x,y)
+            imgsMarginalized = []
+            for g in grayValues:
+                imgChanged = copy.deepcopy(img)
+                imgChanged[x, y] = g
+                imgsMarginalized.append(preprocess(imgChanged, Model.imgSize))
+
+            # put them all into one batch
+            batch = Batch([Constants.gtText] * len(imgsMarginalized), imgsMarginalized)
+
+            # compute probabilities
+            (_, probs) = model.inferBatch(batch, calcProbability=True, probabilityOfGT=True)
+
+            # marginalize over pixel value (assume uniform distribution)
+            margProb = sum([probs[i] * pixelProb[i] for i in range(len(grayValues))])
+
+            pixelRelevance[x, y] = weightOfEvidence(origProb, margProb)
+
+            print(x, y, pixelRelevance[x, y], origProb, margProb)
+
+    np.save(Constants.fnPixelRelevance, pixelRelevance)
 
 
 def analyzeTranslationInvariance():
-	# setup model
-	model = Model(open(Constants.fnCharList).read(), DecoderType.BestPath, mustRestore=True)
-
-	# read image and specify ground-truth text
-	img = cv2.imread(Constants.fnAnalyze, cv2.IMREAD_GRAYSCALE)
-	(w, h) = img.shape
-	assert Model.imgSize[1] == w
-
-	imgList = []
-	for dy in range(Model.imgSize[0]-h+1):
-		targetImg = np.ones((Model.imgSize[1], Model.imgSize[0])) * 255
-		targetImg[:,dy:h+dy] = img
-		imgList.append(preprocess(targetImg, Model.imgSize))
-
-	# put images and gt texts into batch
-	batch = Batch([Constants.gtText]*len(imgList), imgList)
-
-	# compute probabilities
-	(texts, probs) = model.inferBatch(batch, calcProbability=True, probabilityOfGT=True)
-
-	# save results to file
-	f = open(Constants.fnTranslationInvarianceTexts, 'wb')
-	pickle.dump(texts, f)
-	f.close()
-	np.save(Constants.fnTranslationInvariance, probs)
+    # setup model
+    model = Model(open(Constants.fnCharList).read(), DecoderType.BestPath, mustRestore=True)
+
+    # read image and specify ground-truth text
+    img = cv2.imread(Constants.fnAnalyze, cv2.IMREAD_GRAYSCALE)
+    (w, h) = img.shape
+    assert Model.imgSize[1] == w
+
+    imgList = []
+    for dy in range(Model.imgSize[0] - h + 1):
+        targetImg = np.ones((Model.imgSize[1], Model.imgSize[0])) * 255
+        targetImg[:, dy:h + dy] = img
+        imgList.append(preprocess(targetImg, Model.imgSize))
+
+    # put images and gt texts into batch
+    batch = Batch([Constants.gtText] * len(imgList), imgList)
+
+    # compute probabilities
+    (texts, probs) = model.inferBatch(batch, calcProbability=True, probabilityOfGT=True)
+
+    # save results to file
+    f = open(Constants.fnTranslationInvarianceTexts, 'wb')
+    pickle.dump(texts, f)
+    f.close()
+    np.save(Constants.fnTranslationInvariance, probs)
 
 
 def showResults():
-	# 1. pixel relevance
-	pixelRelevance = np.load(Constants.fnPixelRelevance)
-	plt.figure('Pixel relevance')
-
-	plt.imshow(pixelRelevance, cmap=plt.cm.jet, vmin=-0.25, vmax=0.25)
-	plt.colorbar()
-
-	img = cv2.imread(Constants.fnAnalyze, cv2.IMREAD_GRAYSCALE)
-	plt.imshow(img, cmap=plt.cm.gray, alpha=.4)
-
-
-	# 2. translation invariance
-	probs = np.load(Constants.fnTranslationInvariance)
-	f = open(Constants.fnTranslationInvarianceTexts, 'rb')
-	texts = pickle.load(f)
-	texts = ['%d:'%i + texts[i] for i in range(len(texts))]
-	f.close()
-
-	plt.figure('Translation invariance')
-
-	plt.plot(probs, 'o-')
-	plt.xticks(np.arange(len(texts)), texts, rotation='vertical')
-	plt.xlabel('horizontal translation and best path')
-	plt.ylabel('text probability of "%s"'%Constants.gtText)
-
-	# show both plots
-	plt.show()
+    # 1. pixel relevance
+    pixelRelevance = np.load(Constants.fnPixelRelevance)
+    plt.figure('Pixel relevance')
+    plt.imshow(pixelRelevance, cmap=plt.cm.jet, vmin=-0.25, vmax=0.25)
+    plt.colorbar()
 
-if __name__ == '__main__':
-	if len(sys.argv)>1:
-		if sys.argv[1]=='--relevance':
-			print('Analyze pixel relevance')
-			analyzePixelRelevance()
-		elif sys.argv[1]=='--invariance':
-			print('Analyze translation invariance')
-			analyzeTranslationInvariance()
-	else:
-		print('Show results')
-		showResults()
+    img = cv2.imread(Constants.fnAnalyze, cv2.IMREAD_GRAYSCALE)
+    plt.imshow(img, cmap=plt.cm.gray, alpha=.4)
+
+    # 2. translation invariance
+    probs = np.load(Constants.fnTranslationInvariance)
+    f = open(Constants.fnTranslationInvarianceTexts, 'rb')
+    texts = pickle.load(f)
+    texts = ['%d:' % i + texts[i] for i in range(len(texts))]
+    f.close()
+
+    plt.figure('Translation invariance')
+    plt.plot(probs, 'o-')
+    plt.xticks(np.arange(len(texts)), texts, rotation='vertical')
+    plt.xlabel('horizontal translation and best path')
+    plt.ylabel('text probability of "%s"' % Constants.gtText)
+
+    # show both plots
+    plt.show()
+
+
+if __name__ == '__main__':
+    if len(sys.argv) > 1:
+        if sys.argv[1] == '--relevance':
+            print('Analyze pixel relevance')
+            analyzePixelRelevance()
+        elif sys.argv[1] == '--invariance':
+            print('Analyze translation invariance')
+            analyzeTranslationInvariance()
+    else:
+        print('Show results')
+        showResults()
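Reviewer note: weightOfEvidence() is a log-odds ratio in bits, so its sign and magnitude are easy to sanity-check by hand. A hypothetical example: if marginalizing a pixel drops the ground-truth probability from 0.9 to 0.6, that pixel carries about 2.58 bits of positive evidence for the recognized text:

    import math

    def odds(val):
        return val / (1 - val)

    print(math.log2(odds(0.9)) - math.log2(odds(0.6)))  # log2(9 / 1.5) = log2(6) ~ 2.585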
diff --git a/src/main.py b/src/main.py
index 1e712830526d579f001113c4e4a2b3a6de8318ca..2a53c8d907f392f1d63c869fb79f7c2dce19788a 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,8 +1,9 @@
 from __future__ import division
 from __future__ import print_function
 
-import sys
 import argparse
+import sys
+
 import cv2
 import editdistance
 from DataLoader import DataLoader, Batch
@@ -11,135 +12,138 @@ from SamplePreprocessor import preprocess
 
 
 class FilePaths:
-	"filenames and paths to data"
-	fnCharList = '../model/charList.txt'
-	fnAccuracy = '../model/accuracy.txt'
-	fnTrain = '../data/'
-	fnInfer = '../data/test.png'
-	fnCorpus = '../data/corpus.txt'
+    "filenames and paths to data"
+    fnCharList = '../model/charList.txt'
+    fnAccuracy = '../model/accuracy.txt'
+    fnTrain = '../data/'
+    fnInfer = '../data/test.png'
+    fnCorpus = '../data/corpus.txt'
 
 
 def train(model, loader):
-	"train NN"
-	epoch = 0 # number of training epochs since start
-	bestCharErrorRate = float('inf') # best valdiation character error rate
-	noImprovementSince = 0 # number of epochs no improvement of character error rate occured
-	earlyStopping = 5 # stop training after this number of epochs without improvement
-	while True:
-		epoch += 1
-		print('Epoch:', epoch)
-
-		# train
-		print('Train NN')
-		loader.trainSet()
-		while loader.hasNext():
-			iterInfo = loader.getIteratorInfo()
-			batch = loader.getNext()
-			loss = model.trainBatch(batch)
-			print('Batch:', iterInfo[0],'/', iterInfo[1], 'Loss:', loss)
-
-		# validate
-		charErrorRate = validate(model, loader)
-
-		# if best validation accuracy so far, save model parameters
-		if charErrorRate < bestCharErrorRate:
-			print('Character error rate improved, save model')
-			bestCharErrorRate = charErrorRate
-			noImprovementSince = 0
-			model.save()
-			open(FilePaths.fnAccuracy, 'w').write('Validation character error rate of saved model: %f%%' % (charErrorRate*100.0))
-		else:
-			print('Character error rate not improved')
-			noImprovementSince += 1
-
-		# stop training if no more improvement in the last x epochs
-		if noImprovementSince >= earlyStopping:
-			print('No more improvement since %d epochs. Training stopped.' % earlyStopping)
-			break
+    "train NN"
+    epoch = 0  # number of training epochs since start
+    bestCharErrorRate = float('inf')  # best validation character error rate
+    noImprovementSince = 0  # number of epochs no improvement of character error rate occurred
+    earlyStopping = 5  # stop training after this number of epochs without improvement
+    while True:
+        epoch += 1
+        print('Epoch:', epoch)
+
+        # train
+        print('Train NN')
+        loader.trainSet()
+        while loader.hasNext():
+            iterInfo = loader.getIteratorInfo()
+            batch = loader.getNext()
+            loss = model.trainBatch(batch)
+            print('Batch:', iterInfo[0], '/', iterInfo[1], 'Loss:', loss)
+
+        # validate
+        charErrorRate = validate(model, loader)
+
+        # if best validation accuracy so far, save model parameters
+        if charErrorRate < bestCharErrorRate:
+            print('Character error rate improved, save model')
+            bestCharErrorRate = charErrorRate
+            noImprovementSince = 0
+            model.save()
+            open(FilePaths.fnAccuracy, 'w').write(
+                'Validation character error rate of saved model: %f%%' % (charErrorRate * 100.0))
+        else:
+            print('Character error rate not improved')
+            noImprovementSince += 1
+
+        # stop training if no more improvement in the last x epochs
+        if noImprovementSince >= earlyStopping:
+            print('No more improvement since %d epochs. Training stopped.' % earlyStopping)
+            break
 
 
 def validate(model, loader):
-	"validate NN"
-	print('Validate NN')
-	loader.validationSet()
-	numCharErr = 0
-	numCharTotal = 0
-	numWordOK = 0
-	numWordTotal = 0
-	while loader.hasNext():
-		iterInfo = loader.getIteratorInfo()
-		print('Batch:', iterInfo[0],'/', iterInfo[1])
-		batch = loader.getNext()
-		(recognized, _) = model.inferBatch(batch)
-
-		print('Ground truth -> Recognized')
-		for i in range(len(recognized)):
-			numWordOK += 1 if batch.gtTexts[i] == recognized[i] else 0
-			numWordTotal += 1
-			dist = editdistance.eval(recognized[i], batch.gtTexts[i])
-			numCharErr += dist
-			numCharTotal += len(batch.gtTexts[i])
-			print('[OK]' if dist==0 else '[ERR:%d]' % dist,'"' + batch.gtTexts[i] + '"', '->', '"' + recognized[i] + '"')
-
-	# print validation result
-	charErrorRate = numCharErr / numCharTotal
-	wordAccuracy = numWordOK / numWordTotal
-	print('Character error rate: %f%%. Word accuracy: %f%%.' % (charErrorRate*100.0, wordAccuracy*100.0))
-	return charErrorRate
+    "validate NN"
+    print('Validate NN')
+    loader.validationSet()
+    numCharErr = 0
+    numCharTotal = 0
+    numWordOK = 0
+    numWordTotal = 0
+    while loader.hasNext():
+        iterInfo = loader.getIteratorInfo()
+        print('Batch:', iterInfo[0], '/', iterInfo[1])
+        batch = loader.getNext()
+        (recognized, _) = model.inferBatch(batch)
+
+        print('Ground truth -> Recognized')
+        for i in range(len(recognized)):
+            numWordOK += 1 if batch.gtTexts[i] == recognized[i] else 0
+            numWordTotal += 1
+            dist = editdistance.eval(recognized[i], batch.gtTexts[i])
+            numCharErr += dist
+            numCharTotal += len(batch.gtTexts[i])
+            print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + batch.gtTexts[i] + '"', '->',
+                  '"' + recognized[i] + '"')
+
+    # print validation result
+    charErrorRate = numCharErr / numCharTotal
+    wordAccuracy = numWordOK / numWordTotal
+    print('Character error rate: %f%%. Word accuracy: %f%%.' % (charErrorRate * 100.0, wordAccuracy * 100.0))
+    return charErrorRate
 
 
 def infer(model, fnImg):
-	"recognize text in image provided by file path"
-	img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize)
-	batch = Batch(None, [img])
-	(recognized, probability) = model.inferBatch(batch, True)
-	print('Recognized:', '"' + recognized[0] + '"')
-	print('Probability:', probability[0])
+    "recognize text in image provided by file path"
+    img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize)
+    batch = Batch(None, [img])
+    (recognized, probability) = model.inferBatch(batch, True)
+    print('Recognized:', '"' + recognized[0] + '"')
+    print('Probability:', probability[0])
 
 
 def main():
-	"main function"
-	# optional command line args
-	parser = argparse.ArgumentParser()
-	parser.add_argument('--train', help='train the NN', action='store_true')
-	parser.add_argument('--validate', help='validate the NN', action='store_true')
-	parser.add_argument('--beamsearch', help='use beam search instead of best path decoding', action='store_true')
-	parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding', action='store_true')
-	parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
-
-	args = parser.parse_args()
-
-	decoderType = DecoderType.BestPath
-	if args.beamsearch:
-		decoderType = DecoderType.BeamSearch
-	elif args.wordbeamsearch:
-		decoderType = DecoderType.WordBeamSearch
-
-	# train or validate on IAM dataset
-	if args.train or args.validate:
-		# load training data, create TF model
-		loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen)
-
-		# save characters of model for inference mode
-		open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))
-
-		# save words contained in dataset into file
-		open(FilePaths.fnCorpus, 'w').write(str(' ').join(loader.trainWords + loader.validationWords))
-
-		# execute training or validation
-		if args.train:
-			model = Model(loader.charList, decoderType)
-			train(model, loader)
-		elif args.validate:
-			model = Model(loader.charList, decoderType, mustRestore=True)
-			validate(model, loader)
-
-	# infer text on test image
-	else:
-		print(open(FilePaths.fnAccuracy).read())
-		model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump)
-		infer(model, FilePaths.fnInfer)
+    "main function"
+    # optional command line args
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--train', help='train the NN', action='store_true')
+    parser.add_argument('--validate', help='validate the NN', action='store_true')
+    parser.add_argument('--beamsearch', help='use beam search instead of best path decoding', action='store_true')
+    parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding',
+                        action='store_true')
+    parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
+
+    args = parser.parse_args()
+
+    decoderType = DecoderType.BestPath
+    if args.beamsearch:
+        decoderType = DecoderType.BeamSearch
+    elif args.wordbeamsearch:
+        decoderType = DecoderType.WordBeamSearch
+
+    # train or validate on IAM dataset
+    if args.train or args.validate:
+        # load training data, create TF model
+        loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen)
+
+        # save characters of model for inference mode
+        open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))
+
+        # save words contained in dataset into file
+        open(FilePaths.fnCorpus, 'w').write(str(' ').join(loader.trainWords + loader.validationWords))
+
+        # execute training or validation
+        if args.train:
+            model = Model(loader.charList, decoderType)
+            train(model, loader)
+        elif args.validate:
+            model = Model(loader.charList, decoderType, mustRestore=True)
+            validate(model, loader)
+
+    # infer text on test image
+    else:
+        print(open(FilePaths.fnAccuracy).read())
+        model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump)
+        infer(model, FilePaths.fnInfer)
 
 
 if __name__ == '__main__':
-	main()
+    main()
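Reviewer note: end-to-end behaviour is unchanged by this patch. For reference, a minimal inference sketch that mirrors infer() in main.py; it assumes you run from src/, that a trained snapshot exists in ../model/, and that charList.txt was written during training:

    import cv2

    from DataLoader import Batch
    from Model import Model, DecoderType
    from SamplePreprocessor import preprocess

    # restore the saved snapshot and recognize the bundled test image
    model = Model(open('../model/charList.txt').read(), DecoderType.BestPath, mustRestore=True)
    img = preprocess(cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE), Model.imgSize)
    (recognized, probability) = model.inferBatch(Batch(None, [img]), True)
    print('Recognized:', recognized[0], 'Probability:', probability[0])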