Commit 45162d8a authored by Harald Scheidl

reformat code, add note about TF2 in readme

parent ec00c1a3
.gitignore
data/words/
data/words.txt
data/words.7z
src/__pycache__/
model/checkpoint
model/snapshot-*
......
README.md
# Handwritten Text Recognition with TensorFlow
**Update 2020: the code is compatible with TF2** (it keeps the TF1-style graph code working through the `tf.compat.v1` API and disables eager execution, as the changes to `Model.py` below show).
Handwritten Text Recognition (HTR) system implemented with TensorFlow (TF) and trained on the IAM off-line HTR dataset.
This Neural Network (NN) model recognizes the text contained in the images of segmented words as shown in the illustration below.
As these word-images are smaller than images of complete text-lines, the NN can be kept small and training on the CPU is feasible.
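Judging from the argument parser in `src/main.py` (part of this commit), the model is driven from the command line; typical invocations from the `src/` directory would be:

```
python main.py --train           # train the NN on the IAM word images
python main.py --validate        # compute the character error rate on the validation set
python main.py --wordbeamsearch  # decode with word beam search instead of best path
```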
......
src/DataLoader.py
@@ -3,13 +3,16 @@ from __future__ import print_function
import os
import random
-import numpy as np
import cv2
+import numpy as np
from SamplePreprocessor import preprocess
class Sample:
"sample from the dataset"
def __init__(self, gtText, filePath):
self.gtText = gtText
self.filePath = filePath
@@ -17,6 +20,7 @@ class Sample:
class Batch:
"batch containing images and ground truth texts"
def __init__(self, gtTexts, imgs):
self.imgs = np.stack(imgs, axis=0)
self.gtTexts = gtTexts
@@ -50,7 +54,8 @@ class DataLoader:
# filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
fileNameSplit = lineSplit[0].split('-')
-fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + lineSplit[0] + '.png'
+fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + \
+lineSplit[0] + '.png'
# GT text are columns starting at 9
gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen)
@@ -87,7 +92,6 @@ class DataLoader:
# list of all chars in dataset
self.charList = sorted(list(chars))
def truncateLabel(self, text, maxTextLen):
# ctc_loss can't compute loss if it cannot find a mapping between text label and input
# labels. Repeat letters cost double because of the blank symbol needing to be inserted.
@@ -102,7 +106,6 @@ class DataLoader:
return text[:i]
return text
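The two comment lines above compress a subtle point: CTC must emit a blank between repeated characters, so a label with an immediate repeat consumes one extra time step. A minimal sketch of that cost rule (a hypothetical helper, not code from this repo):

    def ctcLabelCost(text):
        # one time step per character, plus one blank step before every immediate repeat
        cost = 0
        for i, ch in enumerate(text):
            cost += 2 if i > 0 and ch == text[i - 1] else 1
        return cost

    ctcLabelCost('less')  # 5: the repeated 's' costs double

`truncateLabel` applies this accounting and, as the two `return` statements above show, cuts the text at the position where the accumulated cost would exceed `maxTextLen`.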
def trainSet(self):
"switch to randomly chosen subset of training set"
self.dataAugmentation = True
@@ -110,30 +113,26 @@
random.shuffle(self.trainSamples)
self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]
def validationSet(self):
"switch to validation set"
self.dataAugmentation = False
self.currIdx = 0
self.samples = self.validationSamples
def getIteratorInfo(self):
"current batch index and overall number of batches"
return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)
def hasNext(self):
"iterator"
return self.currIdx + self.batchSize <= len(self.samples)
def getNext(self):
"iterator"
batchRange = range(self.currIdx, self.currIdx + self.batchSize)
gtTexts = [self.samples[i].gtText for i in batchRange]
-imgs = [preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize, self.dataAugmentation) for i in batchRange]
+imgs = [
+preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize, self.dataAugmentation)
+for i in batchRange]
self.currIdx += self.batchSize
return Batch(gtTexts, imgs)
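Taken together, `trainSet`/`validationSet`, `hasNext` and `getNext` form a simple epoch iterator. A sketch of one pass over the validation set (the `DataLoader` constructor arguments are not shown in this excerpt and are assumed here):

    loader = DataLoader('../data/', batchSize=50, imgSize=(128, 32), maxTextLen=32)  # arguments assumed
    loader.validationSet()                  # augmentation off, index reset to 0
    while loader.hasNext():
        batchNum, totalBatches = loader.getIteratorInfo()
        batch = loader.getNext()            # Batch with stacked .imgs and the list .gtTexts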
src/Model.py
from __future__ import division
from __future__ import print_function
-import sys
import numpy as np
-import tensorflow as tf
import os
+import sys
+import tensorflow as tf
-# Disable eagre
+# Disable eager mode
tf.compat.v1.disable_eager_execution()
class DecoderType:
BestPath = 0
BeamSearch = 1
@@ -52,7 +53,6 @@ class Model:
# initialize TF
(self.sess, self.saver) = self.setupTF()
def setupCNN(self):
"create CNN layers and return output of these layers"
cnnIn4d = tf.expand_dims(input=self.inputImgs, axis=3)
@@ -66,29 +66,33 @@
# create layers
pool = cnnIn4d # input to first CNN layer
for i in range(numLayers):
-kernel = tf.Variable(tf.random.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1))
+kernel = tf.Variable(
+tf.random.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]],
+stddev=0.1))
conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train)
relu = tf.nn.relu(conv_norm)
-pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1), strides=(1, strideVals[i][0], strideVals[i][1], 1), padding='VALID')
+pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1),
+strides=(1, strideVals[i][0], strideVals[i][1], 1), padding='VALID')
self.cnnOut4d = pool
def setupRNN(self):
"create RNN layers and return output of these layers"
rnnIn3d = tf.squeeze(self.cnnOut4d, axis=[2])
# basic cells which are used to build the RNN
numHidden = 256
-cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in range(2)] # 2 layers
+cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in
+range(2)] # 2 layers
# stack basic cells
stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
# bidirectional RNN
# BxTxF -> BxTx2H
-((fw, bw), _) = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, dtype=rnnIn3d.dtype)
+((fw, bw), _) = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d,
+dtype=rnnIn3d.dtype)
# BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)
@@ -97,27 +101,33 @@ class Model:
kernel = tf.Variable(tf.random.truncated_normal([1, 1, numHidden * 2, len(self.charList) + 1], stddev=0.1))
self.rnnOut3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])
def setupCTC(self):
"create CTC loss and decoder and return them"
# BxTxC -> TxBxC
self.ctcIn3dTBC = tf.transpose(a=self.rnnOut3d, perm=[1, 0, 2])
# ground truth text as sparse tensor
-self.gtTexts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]) , tf.compat.v1.placeholder(tf.int32, [None]), tf.compat.v1.placeholder(tf.int64, [2]))
+self.gtTexts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]),
+tf.compat.v1.placeholder(tf.int32, [None]),
+tf.compat.v1.placeholder(tf.int64, [2]))
# calc loss for batch
self.seqLen = tf.compat.v1.placeholder(tf.int32, [None])
-self.loss = tf.reduce_mean(input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True))
+self.loss = tf.reduce_mean(input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC,
+sequence_length=self.seqLen,
+ctc_merge_repeated=True))
# calc loss for each element to compute label probability
-self.savedCtcInput = tf.compat.v1.placeholder(tf.float32, shape=[Model.maxTextLen, None, len(self.charList) + 1])
-self.lossPerElement = tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput, sequence_length=self.seqLen, ctc_merge_repeated=True)
+self.savedCtcInput = tf.compat.v1.placeholder(tf.float32,
+shape=[Model.maxTextLen, None, len(self.charList) + 1])
+self.lossPerElement = tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput,
+sequence_length=self.seqLen, ctc_merge_repeated=True)
# decoder: either best path decoding or beam search decoding
if self.decoderType == DecoderType.BestPath:
self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen)
elif self.decoderType == DecoderType.BeamSearch:
-self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50)
+self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen,
+beam_width=50)
elif self.decoderType == DecoderType.WordBeamSearch:
# import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
word_beam_search_module = tf.load_op_library('TFWordBeamSearch.so')
@@ -128,8 +138,9 @@ class Model:
corpus = open('../data/corpus.txt').read()
# decode using the "Words" mode of word beam search
-self.decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(self.ctcIn3dTBC, axis=2), 50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))
+self.decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(self.ctcIn3dTBC, axis=2), 50, 'Words',
+0.0, corpus.encode('utf8'), chars.encode('utf8'),
+wordChars.encode('utf8'))
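All three branches build `self.decoder` once at graph-construction time, so the decoding strategy is fixed when the model is created. A sketch of selecting it (the full `Model` constructor signature is not shown in this excerpt and is assumed here):

    # charList is the character set collected by the DataLoader; signature assumed
    model = Model(charList, decoderType=DecoderType.BeamSearch)

Note that `DecoderType.WordBeamSearch` additionally requires the compiled custom op `TFWordBeamSearch.so` from the linked repository.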
def setupTF(self):
"initialize TF"
@@ -156,7 +167,6 @@
return (sess, saver)
def toSparse(self, texts):
"put ground truth texts into sparse tensor for ctc_loss"
indices = []
@@ -177,7 +187,6 @@
return (indices, values, shape)
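`tf.compat.v1.nn.ctc_loss` takes its labels as a `SparseTensor`, which is why `toSparse` returns the `(indices, values, shape)` triple. For illustration, with `charList = 'ab'` the batch `['ab', 'b']` would be encoded roughly as:

    indices = [[0, 0], [0, 1], [1, 0]]  # (batch element, character position)
    values = [0, 1, 1]                  # charList.index(c) for each character
    shape = [2, 2]                      # (batch size, length of the longest text)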
def decoderOutputToText(self, ctcOutput, batchSize):
"extract texts from output of CTC decoder"
@@ -208,19 +217,19 @@
# map labels to chars for all batch elements
return [str().join([self.charList[c] for c in labelStr]) for labelStr in encodedLabelStrs]
def trainBatch(self, batch):
"feed a batch into the NN to train it"
numBatchElements = len(batch.imgs)
sparse = self.toSparse(batch.gtTexts)
-rate = 0.01 if self.batchesTrained < 10 else (0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate
+rate = 0.01 if self.batchesTrained < 10 else (
+0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate
evalList = [self.optimizer, self.loss]
-feedDict = {self.inputImgs : batch.imgs, self.gtTexts : sparse , self.seqLen : [Model.maxTextLen] * numBatchElements, self.learningRate : rate, self.is_train: True}
+feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse,
+self.seqLen: [Model.maxTextLen] * numBatchElements, self.learningRate: rate, self.is_train: True}
(_, lossVal) = self.sess.run(evalList, feedDict)
self.batchesTrained += 1
return lossVal
def dumpNNOutput(self, rnnOutput):
"dump the output of the NN to CSV file(s)"
dumpDir = '../dump/'
@@ -240,7 +249,6 @@
with open(fn, 'w') as f:
f.write(csv)
def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False):
"feed a batch into the NN to recognize the texts"
@@ -248,7 +256,8 @@
numBatchElements = len(batch.imgs)
evalRnnOutput = self.dump or calcProbability
evalList = [self.decoder] + ([self.ctcIn3dTBC] if evalRnnOutput else [])
-feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
+feedDict = {self.inputImgs: batch.imgs, self.seqLen: [Model.maxTextLen] * numBatchElements,
+self.is_train: False}
evalRes = self.sess.run(evalList, feedDict)
decoded = evalRes[0]
texts = self.decoderOutputToText(decoded, numBatchElements)
@@ -259,7 +268,8 @@
sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts)
ctcInput = evalRes[1]
evalList = self.lossPerElement
-feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
+feedDict = {self.savedCtcInput: ctcInput, self.gtTexts: sparse,
+self.seqLen: [Model.maxTextLen] * numBatchElements, self.is_train: False}
lossVals = self.sess.run(evalList, feedDict)
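# ctc_loss returns the negative log-probability of each text, so exp(-loss) recovers the probability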
probs = np.exp(-lossVals)
@@ -269,7 +279,6 @@
return (texts, probs)
def save(self):
"save model to file"
self.snapID += 1
......
src/SamplePreprocessor.py
@@ -2,8 +2,9 @@ from __future__ import division
from __future__ import print_function
import random
-import numpy as np
import cv2
+import numpy as np
def preprocess(img, imgSize, dataAugmentation=False):
@@ -25,7 +26,8 @@ def preprocess(img, imgSize, dataAugmentation=False):
fx = w / wt
fy = h / ht
f = max(fx, fy)
-newSize = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1)) # scale according to f (result at least 1 and at most wt or ht)
+newSize = (max(min(wt, int(w / f)), 1),
+max(min(ht, int(h / f)), 1)) # scale according to f (result at least 1 and at most wt or ht)
img = cv2.resize(img, newSize)
target = np.ones([ht, wt]) * 255
target[0:newSize[1], 0:newSize[0]] = img
@@ -40,4 +42,3 @@ def preprocess(img, imgSize, dataAugmentation=False):
img = img - m
img = img / s if s > 0 else img
return img
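`preprocess` therefore downscales while keeping the aspect ratio, pads the result onto a white `wt` x `ht` canvas, and standardizes it to zero mean and unit variance. A minimal usage sketch (the file path and the target size `(128, 32)` are assumed, not taken from this excerpt):

    import cv2
    from SamplePreprocessor import preprocess

    img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)  # path assumed
    img = preprocess(img, (128, 32))  # ready to be stacked into a Batch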
src/analyze.py
from __future__ import division
from __future__ import print_function
-import sys
+import copy
import math
import pickle
-import copy
-import numpy as np
+import sys
import cv2
import matplotlib.pyplot as plt
+import numpy as np
from DataLoader import Batch
from Model import Model, DecoderType
from SamplePreprocessor import preprocess
@@ -126,7 +128,6 @@ def showResults():
img = cv2.imread(Constants.fnAnalyze, cv2.IMREAD_GRAYSCALE)
plt.imshow(img, cmap=plt.cm.gray, alpha=.4)
# 2. translation invariance
probs = np.load(Constants.fnTranslationInvariance)
f = open(Constants.fnTranslationInvarianceTexts, 'rb')
@@ -156,4 +157,3 @@ if __name__ == '__main__':
else:
print('Show results')
showResults()
src/main.py
from __future__ import division
from __future__ import print_function
-import sys
import argparse
+import sys
import cv2
import editdistance
from DataLoader import DataLoader, Batch
@@ -47,7 +48,8 @@ def train(model, loader):
bestCharErrorRate = charErrorRate
noImprovementSince = 0
model.save()
-open(FilePaths.fnAccuracy, 'w').write('Validation character error rate of saved model: %f%%' % (charErrorRate*100.0))
+open(FilePaths.fnAccuracy, 'w').write(
+'Validation character error rate of saved model: %f%%' % (charErrorRate * 100.0))
else:
print('Character error rate not improved')
noImprovementSince += 1
@@ -79,7 +81,8 @@ def validate(model, loader):
dist = editdistance.eval(recognized[i], batch.gtTexts[i])
numCharErr += dist
numCharTotal += len(batch.gtTexts[i])
-print('[OK]' if dist==0 else '[ERR:%d]' % dist,'"' + batch.gtTexts[i] + '"', '->', '"' + recognized[i] + '"')
+print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + batch.gtTexts[i] + '"', '->',
+'"' + recognized[i] + '"')
# print validation result
charErrorRate = numCharErr / numCharTotal
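`charErrorRate` is the classic CER: the summed edit distance between recognized and ground-truth texts, divided by the total number of ground-truth characters. For instance (illustrative values only):

    import editdistance
    editdistance.eval('house', 'hose')  # one deletion -> distance 1
    # a validation set consisting only of this word would give CER = 1/5 = 20%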
@@ -104,7 +107,8 @@ def main():
parser.add_argument('--train', help='train the NN', action='store_true')
parser.add_argument('--validate', help='validate the NN', action='store_true')
parser.add_argument('--beamsearch', help='use beam search instead of best path decoding', action='store_true')
-parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding', action='store_true')
+parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding',
+action='store_true')
parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
args = parser.parse_args()
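`main()` then maps these boolean flags onto a `DecoderType` (that part of the function is elided above); a sketch of what the dispatch plausibly looks like, using only the types defined in `Model.py`:

    decoderType = DecoderType.BestPath  # default
    if args.beamsearch:
        decoderType = DecoderType.BeamSearch
    elif args.wordbeamsearch:
        decoderType = DecoderType.WordBeamSearch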
......