Skip to content
Snippets Groups Projects
Commit 39c4684b authored by Harald Scheidl's avatar Harald Scheidl
Browse files

init

parents
No related branches found
No related tags found
No related merge requests found
data/words/
data/words.txt
src/__pycache__/
model/checkpoint
model/snapshot-*
Get IAM dataset
1. Register at: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
2. Download words.tgz
3. Download words.txt
4. Put words.txt into this dir
5. Create subdir words
6. Put content (directories a01, a02, ...) of words.tgz into subdir words
7. Run checkDirs.py for a rough check on the files
Check if dir structure looks like this:
data
--test.png
--words.txt
--words
----a01
------a01-000u
--------a01-000u-00-00.png
--------...
------...
----a02
----...
import os.path
# Rough sanity check of the dataset layout, relative to the current directory.
checkDirs = ['words/', 'words/a01/a01-000u/']
checkFiles = ['words.txt', 'test.png', 'words/a01/a01-000u/a01-000u-00-00.png']

# Directories are reported first, then files; one line of output per entry.
for (required, exists) in ((checkDirs, os.path.isdir), (checkFiles, os.path.isfile)):
    for f in required:
        print(f, 'ok' if exists(f) else 'not found!!!')
data/test.png

9.12 KiB

!"#&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
\ No newline at end of file
File added
import random
import numpy as np
import cv2
from SamplePreprocessor import preprocess
class Sample:
    """One dataset entry: the ground truth text and the path of its image file."""

    def __init__(self, gtText, filePath):
        # plain value holder, nothing is validated or loaded here
        self.gtText, self.filePath = gtText, filePath
class Batch:
    """A batch of images together with their ground truth texts."""

    def __init__(self, gtTexts, imgs):
        self.gtTexts = gtTexts
        # stack the per-sample arrays along a new leading (batch) axis
        self.imgs = np.stack(imgs)
class DataLoader:
    """Loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"""

    def __init__(self, filePath, batchSize, imgSize, maxTextLen):
        """Loader for dataset at given location, preprocess images and text according to parameters.

        filePath   -- dataset directory, must end with '/'; 'words.txt' and the
                      'words/' image tree are looked up below it
        batchSize  -- number of samples returned by each getNext() call
        imgSize    -- target image size handed to preprocess()
        maxTextLen -- ground truth texts are truncated to this many characters
        """
        # kept as an assert so callers see the same exception type as before
        assert filePath[-1] == '/'

        self.currIdx = 0
        self.batchSize = batchSize
        self.imgSize = imgSize
        self.samples = []

        chars = set()
        # context manager makes sure the annotation file gets closed again
        with open(filePath + 'words.txt') as f:
            for line in f:
                line = line.strip()
                # ignore blank lines and comment lines
                if not line or line[0] == '#':
                    continue

                lineSplit = line.split(' ')
                assert len(lineSplit) >= 9

                # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
                fileNameSplit = lineSplit[0].split('-')
                fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + lineSplit[0] + '.png'

                # GT text are columns starting at 9
                gtText = ' '.join(lineSplit[8:])[:maxTextLen]
                chars.update(gtText)

                # put sample into list
                self.samples.append(Sample(gtText, fileName))

        # split into training and validation set: 95% - 5%
        splitIdx = int(0.95 * len(self.samples))
        self.trainSamples = self.samples[:splitIdx]
        self.validationSamples = self.samples[splitIdx:]

        # start with train set
        self.trainSet()

        # list of all chars in dataset
        self.charList = sorted(chars)

    def trainSet(self):
        """Switch to training set and restart iteration."""
        self.currIdx = 0
        self.samples = self.trainSamples

    def validationSet(self):
        """Switch to validation set and restart iteration."""
        self.currIdx = 0
        self.samples = self.validationSamples

    def shuffle(self):
        """Shuffle current set in place and restart iteration."""
        self.currIdx = 0
        random.shuffle(self.samples)

    def getIteratorInfo(self):
        """Return (current batch index, overall number of full batches)."""
        return (self.currIdx // self.batchSize, len(self.samples) // self.batchSize)

    def hasNext(self):
        """Iterator: True while another full batch is available."""
        # a trailing partial batch is never returned
        return self.currIdx + self.batchSize <= len(self.samples)

    def getNext(self):
        """Iterator: load, preprocess and return the next Batch."""
        batchRange = range(self.currIdx, self.currIdx + self.batchSize)
        gtTexts = [self.samples[i].gtText for i in batchRange]
        imgs = [preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize) for i in batchRange]
        self.currIdx += self.batchSize
        return Batch(gtTexts, imgs)
import sys
import tensorflow as tf
class Model:
    """Minimalistic TF model for HTR: CNN layers, RNN layers and a CTC loss/decoder."""

    # model constants
    batchSize = 50
    imgSize = (128, 32)
    maxTextLen = 32

    def __init__(self, charList):
        """Init model: add CNN, RNN and CTC and initialize TF.

        charList -- indexable sequence of characters; a character's position
                    in it is used as its class id
        """
        self.charList = charList
        self.snapID = 0

        # CNN
        self.inputImgs = tf.placeholder(tf.float32, shape=(Model.batchSize, Model.imgSize[0], Model.imgSize[1]))
        cnnOut4d = self.setupCNN(self.inputImgs)

        # RNN
        rnnOut3d = self.setupRNN(cnnOut4d)

        # CTC
        (self.loss, self.decoder) = self.setupCTC(rnnOut3d)

        # optimizer for NN parameters
        self.optimizer = tf.train.RMSPropOptimizer(0.001).minimize(self.loss)

        # initialize TF
        (self.sess, self.saver) = self.setupTF()

    def setupCNN(self, cnnIn3d):
        """Create CNN layers and return output of these layers."""
        cnnIn4d = tf.expand_dims(input=cnnIn3d, axis=3)

        # list of parameters for the layers
        featureVals = [1, 32, 64, 128, 128, 256]
        strideVals = poolVals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
        numLayers = len(strideVals)
        k = 5  # kernel size, identical for all layers

        # create layers: conv -> relu -> max-pool
        pool = cnnIn4d  # input to first CNN layer
        for i in range(numLayers):
            kernel = tf.Variable(tf.truncated_normal([k, k, featureVals[i], featureVals[i + 1]], stddev=0.1))
            conv = tf.nn.conv2d(pool, kernel, padding='SAME', strides=(1, 1, 1, 1))
            relu = tf.nn.relu(conv)
            pool = tf.nn.max_pool(relu, (1, poolVals[i][0], poolVals[i][1], 1), (1, strideVals[i][0], strideVals[i][1], 1), 'VALID')

        return pool

    def setupRNN(self, rnnIn4d):
        """Create RNN layers and return output of these layers."""
        rnnIn3d = tf.squeeze(rnnIn4d, axis=[2])

        # basic cells which is used to build RNN
        numHidden = 256
        cells = [tf.contrib.rnn.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in range(2)]  # 2 layers

        # stack basic cells
        stacked = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

        # bidirectional RNN
        # BxTxF -> BxTx2H
        ((fw, bw), _) = tf.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, dtype=rnnIn3d.dtype)

        # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
        concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

        # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.truncated_normal([1, 1, numHidden * 2, len(self.charList) + 1], stddev=0.1))
        return tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])

    def setupCTC(self, ctcIn3d):
        """Create CTC loss and decoder and return them as (mean loss, decoder)."""
        # BxTxC -> TxBxC
        ctcIn3dTBC = tf.transpose(ctcIn3d, [1, 0, 2])

        # ground truth text as sparse tensor
        self.gtTexts = tf.SparseTensor(tf.placeholder(tf.int64, shape=[None, 2]), tf.placeholder(tf.int32, [None]), tf.placeholder(tf.int64, [2]))

        # calc loss for batch
        self.seqLen = tf.placeholder(tf.int32, [None])
        loss = tf.nn.ctc_loss(labels=self.gtTexts, inputs=ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True)
        decoder = tf.nn.ctc_greedy_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen)
        return (tf.reduce_mean(loss), decoder)

    def setupTF(self):
        """Initialize TF and return (session, saver)."""
        print('Python: ' + sys.version)
        print('Tensorflow: ' + tf.__version__)

        sess = tf.Session()  # TF session

        saver = tf.train.Saver(max_to_keep=1)  # saver saves model to file
        latestSnapshot = tf.train.latest_checkpoint('../model/')  # is there a saved model?

        # no saved model -> init with new values
        if not latestSnapshot:
            print('Init with new values')
            sess.run(tf.global_variables_initializer())
        # init with saved values
        else:
            print('Init with stored values from ' + latestSnapshot)
            saver.restore(sess, latestSnapshot)

        return (sess, saver)

    def toSparse(self, texts):
        """Put ground truth texts into sparse tensor for ctc_loss.

        Returns (indices, values, shape): indices are [batchElement, position]
        pairs, values the class ids, shape is [number of texts, longest text].
        """
        indices = []
        values = []
        shape = [len(texts), 0]  # last entry becomes the length of the longest label string

        # go over all texts
        for (batchElement, text) in enumerate(texts):
            # convert to string of label (i.e. class-ids)
            labelStr = [self.charList.index(c) for c in text]
            # sparse tensor must have size of max. label-string
            if len(labelStr) > shape[1]:
                shape[1] = len(labelStr)
            # put each label into sparse tensor
            for (i, label) in enumerate(labelStr):
                indices.append([batchElement, i])
                values.append(label)

        return (indices, values, shape)

    def fromSparse(self, ctcOutput):
        """Extract texts from sparse tensor; returns one string per batch element."""
        # ctc returns tuple, first element is SparseTensor
        decoded = ctcOutput[0][0]

        # collect the labels of each batch element, in the order they appear
        encodedLabelStrs = [[] for _ in range(Model.batchSize)]
        for (idx2d, label) in zip(decoded.indices, decoded.values):
            batchElement = idx2d[0]  # index according to [b,t]
            encodedLabelStrs[batchElement].append(label)

        # map labels to chars for all batch elements
        return [str().join([self.charList[c] for c in labelStr]) for labelStr in encodedLabelStrs]

    def trainBatch(self, batch):
        """Feed a batch into the NN to train it and return the loss value."""
        sparse = self.toSparse(batch.gtTexts)
        # every sequence is fed with length Model.maxTextLen
        # NOTE(review): presumably this matches the number of time steps of the CNN output - confirm
        (_, lossVal) = self.sess.run([self.optimizer, self.loss], {self.inputImgs: batch.imgs, self.gtTexts: sparse, self.seqLen: [Model.maxTextLen] * Model.batchSize})
        return lossVal

    def inferBatch(self, batch):
        """Feed a batch into the NN to recognize the texts."""
        decoded = self.sess.run(self.decoder, {self.inputImgs: batch.imgs, self.seqLen: [Model.maxTextLen] * Model.batchSize})
        return self.fromSparse(decoded)

    def save(self):
        """Save model to file, using an increasing snapshot id as global step."""
        self.snapID += 1
        self.saver.save(self.sess, '../model/snapshot', global_step=self.snapID)
import numpy as np
import cv2
def preprocess(img, imgSize):
    """Fit a grayscale image into a white canvas of the target size and normalize it.

    img     -- 2D image array, or None if the file could not be read
    imgSize -- target size as (width, height)

    Returns the transposed canvas with the mean subtracted and, if the
    standard deviation is non-zero, divided by the standard deviation.
    """
    # there are damaged files in IAM dataset - just use black image instead
    if img is None:
        img = np.zeros([1, 1])

    # scale the sample by the larger of the two ratios so that it fits the target
    (wt, ht) = imgSize
    (h, w) = img.shape
    fx = w / wt
    fy = h / ht
    f = max(fx, fy)
    # clamp to at least 1 pixel: for very elongated images int() rounds the
    # short side down to 0, which cv2.resize does not accept
    newSize = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1))
    img = cv2.resize(img, newSize)

    # create target image (filled with 255) and copy sample image into its top-left corner
    target = np.ones([ht, wt]) * 255
    target[0:newSize[1], 0:newSize[0]] = img

    # transpose for TF
    img = cv2.transpose(target)

    # normalize
    (m, s) = cv2.meanStdDev(img)
    m = m[0][0]
    s = s[0][0]
    img = img - m
    img = img / s if s > 0 else img
    return img
import sys
import cv2
from DataLoader import DataLoader, Batch
from Model import Model
from SamplePreprocessor import preprocess
# filenames and paths to data (relative paths, resolved against the current working directory)
# character list: written by train(), read back by infer()
fnCharList = '../model/charList.txt'
# dataset directory handed to train() in the main guard
fnTrain = '../data/'
# image that is recognized in inference mode
fnInfer = '../data/test.png'
def train(filePath):
    """Train NN on the dataset located at filePath.

    Runs forever: each epoch saves a snapshot, trains on the shuffled
    training set and then reports the share of correctly recognized words
    on the validation set.
    """
    # load training data
    loader = DataLoader(filePath, Model.batchSize, Model.imgSize, Model.maxTextLen)

    # create TF model
    model = Model(loader.charList)

    # save characters of model for inference mode
    with open(fnCharList, 'w') as f:
        f.write(str().join(loader.charList))

    # train forever
    epoch = 0
    while True:
        print('Epoch:', epoch)
        model.save()

        # train
        print('Train NN')
        loader.trainSet()
        loader.shuffle()
        while loader.hasNext():
            iterInfo = loader.getIteratorInfo()
            batch = loader.getNext()
            loss = model.trainBatch(batch)
            print('Iterator:', iterInfo, 'Loss:', loss)

        # validate: inference only, the validation data must not be used for training
        print('Validate NN')
        loader.validationSet()
        numOK = 0
        numTotal = 0
        while loader.hasNext():
            iterInfo = loader.getIteratorInfo()
            print('Iterator:', iterInfo)
            batch = loader.getNext()
            recognized = model.inferBatch(batch)

            print('Ground truth -> Recognized')
            for (gtText, text) in zip(batch.gtTexts, recognized):
                isOK = gtText == text
                print('[OK]' if isOK else '[ERR]', '"' + gtText + '"', '->', '"' + text + '"')
                numOK += 1 if isOK else 0
                numTotal += 1

        # print validation result (0.0 if the validation set held no full batch)
        rate = numOK / numTotal * 100.0 if numTotal else 0.0
        print('Correctly recognized words:', rate, '%')

        epoch += 1
def infer(filePath):
    """Recognize text in image provided by file path and print the result."""
    # the character list stored in fnCharList defines the class ids of the model
    with open(fnCharList) as f:
        charList = f.read()
    model = Model(charList)

    # read the image named by the parameter, not a fixed global path
    img = preprocess(cv2.imread(filePath, cv2.IMREAD_GRAYSCALE), Model.imgSize)

    # fill a whole batch with copies of the image; only the first result is reported
    batch = Batch(None, [img] * Model.batchSize)
    recognized = model.inferBatch(batch)
    print('Recognized:', '"' + recognized[0] + '"')
if __name__ == '__main__':
    # exactly one argument equal to 'train' selects training, anything else runs inference
    wantsTraining = len(sys.argv) == 2 and sys.argv[1] == 'train'
    if not wantsTraining:
        infer(fnInfer)
    else:
        train(fnTrain)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment