Commit f3945bd7 authored by Harald Scheidl

reworked code: faster dataloader, more data augmentations, and some more changes...

parent baecb532
-data/words/
-data/words.txt
+data/words*
+data/words.txt*
+data/words.7z
src/__pycache__/
model/checkpoint
model/snapshot-*
......
-from __future__ import division
-from __future__ import print_function
-import os
+import pickle
import random
import cv2
+import lmdb
import numpy as np
+from path import Path
from SamplePreprocessor import preprocess
@@ -26,13 +25,17 @@ class Batch:
        self.gtTexts = gtTexts
-class DataLoader:
+class DataLoaderIAM:
    "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
-    def __init__(self, filePath, batchSize, imgSize, maxTextLen):
+    def __init__(self, data_dir, batchSize, imgSize, maxTextLen, fast=True):
        "loader for dataset at given location, preprocess images and text according to parameters"
-        assert filePath[-1] == '/'
+        assert data_dir.exists()
+        self.fast = fast
+        if fast:
+            self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)
        self.dataAugmentation = False
        self.currIdx = 0
@@ -40,10 +43,9 @@ class DataLoader:
        self.imgSize = imgSize
        self.samples = []
-        f = open(filePath + 'words.txt')
+        f = open(data_dir / 'gt/words.txt')
        chars = set()
-        bad_samples = []
-        bad_samples_reference = ['a01-117-05-02.png', 'r06-022-03-05.png']
+        bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset
        for line in f:
            # ignore comment line
            if not line or line[0] == '#':
@@ -54,26 +56,19 @@ class DataLoader:
            # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
            fileNameSplit = lineSplit[0].split('-')
-            fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + \
-                       lineSplit[0] + '.png'
+            fileName = data_dir / 'img' / fileNameSplit[0] / f'{fileNameSplit[0]}-{fileNameSplit[1]}' / lineSplit[0] + '.png'
+            if lineSplit[0] in bad_samples_reference:
+                print('Ignoring known broken image:', fileName)
+                continue
            # GT text are columns starting at 9
            gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen)
            chars = chars.union(set(list(gtText)))
-            # check if image is not empty
-            if not os.path.getsize(fileName):
-                bad_samples.append(lineSplit[0] + '.png')
-                continue
            # put sample into list
            self.samples.append(Sample(gtText, fileName))
-        # some images in the IAM dataset are known to be damaged, don't show warning for them
-        if set(bad_samples) != set(bad_samples_reference):
-            print("Warning, damaged images found:", bad_samples)
-            print("Damaged images expected:", bad_samples_reference)
        # split into training and validation set: 95% - 5%
        splitIdx = int(0.95 * len(self.samples))
        self.trainSamples = self.samples[:splitIdx]
@@ -131,8 +126,23 @@ class DataLoader:
        "iterator"
        batchRange = range(self.currIdx, self.currIdx + self.batchSize)
        gtTexts = [self.samples[i].gtText for i in batchRange]
-        imgs = [
-            preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize, self.dataAugmentation)
-            for i in batchRange]
+        imgs = []
+        for i in batchRange:
+            if self.fast:
+                with self.env.begin() as txn:
+                    basename = Path(self.samples[i].filePath).basename()
+                    data = txn.get(basename.encode("ascii"))
+                    img = pickle.loads(data)
+            else:
+                img = cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE)
+            imgs.append(preprocess(img, self.imgSize, self.dataAugmentation))
        self.currIdx += self.batchSize
        return Batch(gtTexts, imgs)
+if __name__ == '__main__':
+    dl = DataLoaderIAM('../data/', 50, (128, 32), 32)
+    dl.getNext()
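For context, the fast path of getNext() above looks each image up in an lmdb keyed by the png basename and unpickles the stored grayscale array. A minimal standalone sketch of that lookup (the lmdb location and the sample filename below are assumptions for illustration, not taken from the commit):

import pickle
import lmdb
from path import Path

env = lmdb.open('../data/lmdb', readonly=True)  # assumed location of the database built from the IAM images
with env.begin() as txn:
    key = Path('a01-000u-00-00.png').basename().encode('ascii')  # hypothetical sample filename
    data = txn.get(key)
    if data is not None:
        img = pickle.loads(data)  # grayscale numpy array, as stored from cv2.imread(..., cv2.IMREAD_GRAYSCALE)
        print(img.shape)
env.close()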
-from __future__ import division
-from __future__ import print_function
import numpy as np
import os
import sys
@@ -20,7 +17,6 @@ class Model:
    "minimalistic TF model for HTR"
    # model constants
-    batchSize = 50
    imgSize = (128, 32)
    maxTextLen = 32
@@ -45,10 +41,9 @@ class Model:
        # setup optimizer to train NN
        self.batchesTrained = 0
-        self.learningRate = tf.compat.v1.placeholder(tf.float32, shape=[])
        self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
-            self.optimizer = tf.compat.v1.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)
+            self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss)
        # initialize TF
        (self.sess, self.saver) = self.setupTF()
@@ -221,12 +216,9 @@ class Model:
        "feed a batch into the NN to train it"
        numBatchElements = len(batch.imgs)
        sparse = self.toSparse(batch.gtTexts)
-        rate = 0.01 if self.batchesTrained < 10 else (
-            0.001 if self.batchesTrained < 10000 else 0.0001)  # decay learning rate
        evalList = [self.optimizer, self.loss]
-        feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse,
-                    self.seqLen: [Model.maxTextLen] * numBatchElements, self.learningRate: rate, self.is_train: True}
-        (_, lossVal) = self.sess.run(evalList, feedDict)
+        feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse, self.seqLen: [Model.maxTextLen] * numBatchElements, self.is_train: True}
+        _, lossVal = self.sess.run(evalList, feedDict)
        self.batchesTrained += 1
        return lossVal
......
-from __future__ import division
-from __future__ import print_function
import random
import cv2
@@ -14,9 +11,25 @@ def preprocess(img, imgSize, dataAugmentation=False):
    if img is None:
        img = np.zeros([imgSize[1], imgSize[0]])
+    img = img.astype(np.float)
    # increase dataset size by applying random stretches to the images
    if dataAugmentation:
-        stretch = (random.random() - 0.5)  # -0.5 .. +0.5
+        if random.random() < 0.25:
+            rand_odd = lambda: random.randint(1, 3) * 2 + 1
+            img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
+        if random.random() < 0.25:
+            img = cv2.dilate(img, np.ones((3, 3)))
+        if random.random() < 0.25:
+            img = cv2.erode(img, np.ones((3, 3)))
+        if random.random() < 0.5:
+            img = img * (0.25 + random.random() * 0.75)
+        if random.random() < 0.25:
+            img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 50), 0, 255)
+        if random.random() < 0.1:
+            img = 255 - img
+        stretch = random.random() - 0.5  # -0.5 .. +0.5
        wStretched = max(int(img.shape[1] * (1 + stretch)), 1)  # random width, but at least 1
        img = cv2.resize(img, (wStretched, img.shape[0]))  # stretch horizontally by factor 0.5 .. 1.5
@@ -29,16 +42,32 @@ def preprocess(img, imgSize, dataAugmentation=False):
    newSize = (max(min(wt, int(w / f)), 1),
               max(min(ht, int(h / f)), 1))  # scale according to f (result at least 1 and at most wt or ht)
    img = cv2.resize(img, newSize)
-    target = np.ones([ht, wt]) * 255
-    target[0:newSize[1], 0:newSize[0]] = img
+    target = np.ones([ht, wt]) * 127.5
+    r_freedom = target.shape[0] - img.shape[0]
+    c_freedom = target.shape[1] - img.shape[1]
+    if dataAugmentation:
+        r_off, c_off = random.randint(0, r_freedom), random.randint(0, c_freedom)
+    else:
+        r_off, c_off = r_freedom // 2, c_freedom // 2
+    target[r_off:img.shape[0] + r_off, c_off:img.shape[1] + c_off] = img
    # transpose for TF
    img = cv2.transpose(target)
-    # normalize
-    (m, s) = cv2.meanStdDev(img)
-    m = m[0][0]
-    s = s[0][0]
-    img = img - m
-    img = img / s if s > 0 else img
+    # convert to range [-1, 1]
+    img = img / 255 - 0.5
    return img
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+    img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
+    img_aug = preprocess(img, (128, 32), True)
+    plt.subplot(121)
+    plt.imshow(img)
+    plt.subplot(122)
+    plt.imshow(cv2.transpose(img_aug))
+    plt.show()
\ No newline at end of file
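A small numeric check of the new padding and scaling above: the background value 127.5 maps to 0 after img / 255 - 0.5, while pixel values 0 and 255 map to -0.5 and +0.5 respectively (a quick sketch, assuming numpy is imported as np as in the module):

import numpy as np
vals = np.array([0.0, 127.5, 255.0])
print(vals / 255 - 0.5)  # [-0.5  0.   0.5]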
-from __future__ import division
-from __future__ import print_function
import copy
import math
import pickle
......
import argparse
import pickle
import cv2
import lmdb
from path import Path
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=Path, required=True)
args = parser.parse_args()
# 2GB is enough for IAM dataset
assert not (args.data_dir / 'lmdb').exists()
env = lmdb.open(str(args.data_dir / 'lmdb'), map_size=1024 * 1024 * 1024 * 2)
# go over all png files
fn_imgs = list((args.data_dir / 'img').walkfiles('*.png'))
# and put the imgs into lmdb as pickled grayscale imgs
with env.begin(write=True) as txn:
    for i, fn_img in enumerate(fn_imgs):
        print(i, len(fn_imgs))
        img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
        basename = fn_img.basename()
        txn.put(basename.encode("ascii"), pickle.dumps(img))
env.close()
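The script above is a new file in this commit; its filename is not visible in this extract, but assuming it is saved as create_lmdb.py alongside the other sources, it would be run once before training, for example:

python create_lmdb.py --data_dir ../data/

It walks <data_dir>/img for png files and stores each image, pickled, under its basename in <data_dir>/lmdb, which is exactly what the fast path of DataLoaderIAM reads back.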
-from __future__ import division
-from __future__ import print_function
import argparse
import cv2
import editdistance
-from DataLoader import DataLoader, Batch
+from DataLoaderIAM import DataLoaderIAM, Batch
from Model import Model, DecoderType
from SamplePreprocessor import preprocess
+from path import Path
class FilePaths:
    "filenames and paths to data"
    fnCharList = '../model/charList.txt'
    fnAccuracy = '../model/accuracy.txt'
-    fnTrain = '../data/'
    fnInfer = '../data/test.png'
    fnCorpus = '../data/corpus.txt'
@@ -25,7 +22,7 @@ def train(model, loader):
    epoch = 0  # number of training epochs since start
    bestCharErrorRate = float('inf')  # best valdiation character error rate
    noImprovementSince = 0  # number of epochs no improvement of character error rate occured
-    earlyStopping = 5  # stop training after this number of epochs without improvement
+    earlyStopping = 25  # stop training after this number of epochs without improvement
    while True:
        epoch += 1
        print('Epoch:', epoch)
@@ -37,7 +34,7 @@ def train(model, loader):
            iterInfo = loader.getIteratorInfo()
            batch = loader.getNext()
            loss = model.trainBatch(batch)
-            print('Batch:', iterInfo[0], '/', iterInfo[1], 'Loss:', loss)
+            print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}')
        # validate
        charErrorRate = validate(model, loader)
@@ -49,14 +46,14 @@ def train(model, loader):
            noImprovementSince = 0
            model.save()
            open(FilePaths.fnAccuracy, 'w').write(
-                'Validation character error rate of saved model: %f%%' % (charErrorRate * 100.0))
+                f'Validation character error rate of saved model: {charErrorRate * 100.0}%')
        else:
-            print('Character error rate not improved')
+            print(f'Character error rate not improved, best so far: {charErrorRate * 100.0}%')
            noImprovementSince += 1
        # stop training if no more improvement in the last x epochs
        if noImprovementSince >= earlyStopping:
-            print('No more improvement since %d epochs. Training stopped.' % earlyStopping)
+            print(f'No more improvement since {earlyStopping} epochs. Training stopped.')
            break
@@ -70,7 +67,7 @@ def validate(model, loader):
    numWordTotal = 0
    while loader.hasNext():
        iterInfo = loader.getIteratorInfo()
-        print('Batch:', iterInfo[0], '/', iterInfo[1])
+        print(f'Batch: {iterInfo[0]} / {iterInfo[1]}')
        batch = loader.getNext()
        (recognized, _) = model.inferBatch(batch)
@@ -87,7 +84,7 @@ def validate(model, loader):
    # print validation result
    charErrorRate = numCharErr / numCharTotal
    wordAccuracy = numWordOK / numWordTotal
-    print('Character error rate: %f%%. Word accuracy: %f%%.' % (charErrorRate * 100.0, wordAccuracy * 100.0))
+    print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.')
    return charErrorRate
@@ -96,8 +93,8 @@ def infer(model, fnImg):
    img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize)
    batch = Batch(None, [img])
    (recognized, probability) = model.inferBatch(batch, True)
-    print('Recognized:', '"' + recognized[0] + '"')
-    print('Probability:', probability[0])
+    print(f'Recognized: "{recognized[0]}"')
+    print(f'Probability: {probability[0]}')
def main():
@@ -110,6 +107,9 @@ def main():
    parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding',
                        action='store_true')
    parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
+    parser.add_argument('--fast', help='use lmdb to load images', action='store_true')
+    parser.add_argument('--data_dir', help='directory containing IAM dataset', type=Path, required=False)
+    parser.add_argument('--batch_size', help='batch size', type=int, default=100)
    args = parser.parse_args()
@@ -122,7 +122,7 @@ def main():
    # train or validate on IAM dataset
    if args.train or args.validate:
        # load training data, create TF model
-        loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen)
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.imgSize, Model.maxTextLen, args.fast)
        # save characters of model for inference mode
        open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))
......
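With the new command line options wired into main(), a training run could be started roughly like this (a sketch; the --train flag comes from the unchanged part of the argument parser, and the paths are assumptions):

python main.py --train --data_dir ../data/ --fast --batch_size 100

Without --fast the loader falls back to cv2.imread for every sample; with it, images come from the lmdb prepared by the new script above.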