diff --git a/.gitignore b/.gitignore index 1b9575d156b3cb7320b4dfc73039b72535140fbf..8efcba50d500e0da4d36918183250ec11358fc3c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ -data/words/ -data/words.txt -data/words.7z +data/words* +data/words.txt* src/__pycache__/ model/checkpoint model/snapshot-* diff --git a/src/DataLoader.py b/src/DataLoaderIAM.py similarity index 74% rename from src/DataLoader.py rename to src/DataLoaderIAM.py index 002d5711525da4d5d5ee3931b6c4d9d305ef6733..7b4910f0323abf5f6ea7194d4bccc9ede85db674 100644 --- a/src/DataLoader.py +++ b/src/DataLoaderIAM.py @@ -1,138 +1,148 @@ -from __future__ import division -from __future__ import print_function - -import os -import random - -import cv2 -import numpy as np - -from SamplePreprocessor import preprocess - - -class Sample: - "sample from the dataset" - - def __init__(self, gtText, filePath): - self.gtText = gtText - self.filePath = filePath - - -class Batch: - "batch containing images and ground truth texts" - - def __init__(self, gtTexts, imgs): - self.imgs = np.stack(imgs, axis=0) - self.gtTexts = gtTexts - - -class DataLoader: - "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database" - - def __init__(self, filePath, batchSize, imgSize, maxTextLen): - "loader for dataset at given location, preprocess images and text according to parameters" - - assert filePath[-1] == '/' - - self.dataAugmentation = False - self.currIdx = 0 - self.batchSize = batchSize - self.imgSize = imgSize - self.samples = [] - - f = open(filePath + 'words.txt') - chars = set() - bad_samples = [] - bad_samples_reference = ['a01-117-05-02.png', 'r06-022-03-05.png'] - for line in f: - # ignore comment line - if not line or line[0] == '#': - continue - - lineSplit = line.strip().split(' ') - assert len(lineSplit) >= 9 - - # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png - fileNameSplit = lineSplit[0].split('-') - fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + \ - lineSplit[0] + '.png' - - # GT text are columns starting at 9 - gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen) - chars = chars.union(set(list(gtText))) - - # check if image is not empty - if not os.path.getsize(fileName): - bad_samples.append(lineSplit[0] + '.png') - continue - - # put sample into list - self.samples.append(Sample(gtText, fileName)) - - # some images in the IAM dataset are known to be damaged, don't show warning for them - if set(bad_samples) != set(bad_samples_reference): - print("Warning, damaged images found:", bad_samples) - print("Damaged images expected:", bad_samples_reference) - - # split into training and validation set: 95% - 5% - splitIdx = int(0.95 * len(self.samples)) - self.trainSamples = self.samples[:splitIdx] - self.validationSamples = self.samples[splitIdx:] - - # put words into lists - self.trainWords = [x.gtText for x in self.trainSamples] - self.validationWords = [x.gtText for x in self.validationSamples] - - # number of randomly chosen samples per epoch for training - self.numTrainSamplesPerEpoch = 25000 - - # start with train set - self.trainSet() - - # list of all chars in dataset - self.charList = sorted(list(chars)) - - def truncateLabel(self, text, maxTextLen): - # ctc_loss can't compute loss if it cannot find a mapping between text label and input - # labels. Repeat letters cost double because of the blank symbol needing to be inserted. - # If a too-long label is provided, ctc_loss returns an infinite gradient - cost = 0 - for i in range(len(text)): - if i != 0 and text[i] == text[i - 1]: - cost += 2 - else: - cost += 1 - if cost > maxTextLen: - return text[:i] - return text - - def trainSet(self): - "switch to randomly chosen subset of training set" - self.dataAugmentation = True - self.currIdx = 0 - random.shuffle(self.trainSamples) - self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch] - - def validationSet(self): - "switch to validation set" - self.dataAugmentation = False - self.currIdx = 0 - self.samples = self.validationSamples - - def getIteratorInfo(self): - "current batch index and overall number of batches" - return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize) - - def hasNext(self): - "iterator" - return self.currIdx + self.batchSize <= len(self.samples) - - def getNext(self): - "iterator" - batchRange = range(self.currIdx, self.currIdx + self.batchSize) - gtTexts = [self.samples[i].gtText for i in batchRange] - imgs = [ - preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize, self.dataAugmentation) - for i in batchRange] - self.currIdx += self.batchSize - return Batch(gtTexts, imgs) +import pickle +import random + +import cv2 +import lmdb +import numpy as np +from path import Path + +from SamplePreprocessor import preprocess + + +class Sample: + "sample from the dataset" + + def __init__(self, gtText, filePath): + self.gtText = gtText + self.filePath = filePath + + +class Batch: + "batch containing images and ground truth texts" + + def __init__(self, gtTexts, imgs): + self.imgs = np.stack(imgs, axis=0) + self.gtTexts = gtTexts + + +class DataLoaderIAM: + "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database" + + def __init__(self, data_dir, batchSize, imgSize, maxTextLen, fast=True): + "loader for dataset at given location, preprocess images and text according to parameters" + + assert data_dir.exists() + + self.fast = fast + if fast: + self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True) + + self.dataAugmentation = False + self.currIdx = 0 + self.batchSize = batchSize + self.imgSize = imgSize + self.samples = [] + + f = open(data_dir / 'gt/words.txt') + chars = set() + bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05'] # known broken images in IAM dataset + for line in f: + # ignore comment line + if not line or line[0] == '#': + continue + + lineSplit = line.strip().split(' ') + assert len(lineSplit) >= 9 + + # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png + fileNameSplit = lineSplit[0].split('-') + fileName = data_dir / 'img' / fileNameSplit[0] / f'{fileNameSplit[0]}-{fileNameSplit[1]}' / lineSplit[0] + '.png' + + if lineSplit[0] in bad_samples_reference: + print('Ignoring known broken image:', fileName) + continue + + # GT text are columns starting at 9 + gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen) + chars = chars.union(set(list(gtText))) + + # put sample into list + self.samples.append(Sample(gtText, fileName)) + + # split into training and validation set: 95% - 5% + splitIdx = int(0.95 * len(self.samples)) + self.trainSamples = self.samples[:splitIdx] + self.validationSamples = self.samples[splitIdx:] + + # put words into lists + self.trainWords = [x.gtText for x in self.trainSamples] + self.validationWords = [x.gtText for x in self.validationSamples] + + # number of randomly chosen samples per epoch for training + self.numTrainSamplesPerEpoch = 25000 + + # start with train set + self.trainSet() + + # list of all chars in dataset + self.charList = sorted(list(chars)) + + def truncateLabel(self, text, maxTextLen): + # ctc_loss can't compute loss if it cannot find a mapping between text label and input + # labels. Repeat letters cost double because of the blank symbol needing to be inserted. + # If a too-long label is provided, ctc_loss returns an infinite gradient + cost = 0 + for i in range(len(text)): + if i != 0 and text[i] == text[i - 1]: + cost += 2 + else: + cost += 1 + if cost > maxTextLen: + return text[:i] + return text + + def trainSet(self): + "switch to randomly chosen subset of training set" + self.dataAugmentation = True + self.currIdx = 0 + random.shuffle(self.trainSamples) + self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch] + + def validationSet(self): + "switch to validation set" + self.dataAugmentation = False + self.currIdx = 0 + self.samples = self.validationSamples + + def getIteratorInfo(self): + "current batch index and overall number of batches" + return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize) + + def hasNext(self): + "iterator" + return self.currIdx + self.batchSize <= len(self.samples) + + def getNext(self): + "iterator" + batchRange = range(self.currIdx, self.currIdx + self.batchSize) + gtTexts = [self.samples[i].gtText for i in batchRange] + + imgs = [] + for i in batchRange: + if self.fast: + with self.env.begin() as txn: + basename = Path(self.samples[i].filePath).basename() + data = txn.get(basename.encode("ascii")) + img = pickle.loads(data) + else: + img = cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE) + + imgs.append(preprocess(img, self.imgSize, self.dataAugmentation)) + + self.currIdx += self.batchSize + return Batch(gtTexts, imgs) + + +if __name__ == '__main__': + dl = DataLoaderIAM('../data/', 50, (128, 32), 32) + dl.getNext() diff --git a/src/Model.py b/src/Model.py index 07e300fd880eb93ba401eb05b792d764ae8bd2fc..1b58629e6c6c4703d599395548ef719ae8f39524 100644 --- a/src/Model.py +++ b/src/Model.py @@ -1,6 +1,3 @@ -from __future__ import division -from __future__ import print_function - import numpy as np import os import sys @@ -20,7 +17,6 @@ class Model: "minimalistic TF model for HTR" # model constants - batchSize = 50 imgSize = (128, 32) maxTextLen = 32 @@ -45,10 +41,9 @@ class Model: # setup optimizer to train NN self.batchesTrained = 0 - self.learningRate = tf.compat.v1.placeholder(tf.float32, shape=[]) self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) with tf.control_dependencies(self.update_ops): - self.optimizer = tf.compat.v1.train.RMSPropOptimizer(self.learningRate).minimize(self.loss) + self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss) # initialize TF (self.sess, self.saver) = self.setupTF() @@ -221,12 +216,9 @@ class Model: "feed a batch into the NN to train it" numBatchElements = len(batch.imgs) sparse = self.toSparse(batch.gtTexts) - rate = 0.01 if self.batchesTrained < 10 else ( - 0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate evalList = [self.optimizer, self.loss] - feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse, - self.seqLen: [Model.maxTextLen] * numBatchElements, self.learningRate: rate, self.is_train: True} - (_, lossVal) = self.sess.run(evalList, feedDict) + feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse, self.seqLen: [Model.maxTextLen] * numBatchElements, self.is_train: True} + _, lossVal = self.sess.run(evalList, feedDict) self.batchesTrained += 1 return lossVal diff --git a/src/SamplePreprocessor.py b/src/SamplePreprocessor.py index 1d571e9f33a858c58021ef681996b19c2c306424..5547eecf4ddc6523f70179a29f06d608a9f53ab6 100644 --- a/src/SamplePreprocessor.py +++ b/src/SamplePreprocessor.py @@ -1,6 +1,3 @@ -from __future__ import division -from __future__ import print_function - import random import cv2 @@ -14,9 +11,25 @@ def preprocess(img, imgSize, dataAugmentation=False): if img is None: img = np.zeros([imgSize[1], imgSize[0]]) + img = img.astype(np.float) + # increase dataset size by applying random stretches to the images if dataAugmentation: - stretch = (random.random() - 0.5) # -0.5 .. +0.5 + if random.random() < 0.25: + rand_odd = lambda: random.randint(1, 3) * 2 + 1 + img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0) + if random.random() < 0.25: + img = cv2.dilate(img,np.ones((3,3))) + if random.random() < 0.25: + img = cv2.erode(img,np.ones((3,3))) + if random.random() < 0.5: + img = img * (0.25 + random.random() * 0.75) + if random.random() < 0.25: + img = np.clip(img + (np.random.random(img.shape)-0.5) * random.randint(1, 50), 0, 255) + if random.random() < 0.1: + img = 255 - img + + stretch = random.random() - 0.5 # -0.5 .. +0.5 wStretched = max(int(img.shape[1] * (1 + stretch)), 1) # random width, but at least 1 img = cv2.resize(img, (wStretched, img.shape[0])) # stretch horizontally by factor 0.5 .. 1.5 @@ -29,16 +42,32 @@ def preprocess(img, imgSize, dataAugmentation=False): newSize = (max(min(wt, int(w / f)), 1), max(min(ht, int(h / f)), 1)) # scale according to f (result at least 1 and at most wt or ht) img = cv2.resize(img, newSize) - target = np.ones([ht, wt]) * 255 - target[0:newSize[1], 0:newSize[0]] = img + target = np.ones([ht, wt]) * 127.5 + + r_freedom = target.shape[0] - img.shape[0] + c_freedom = target.shape[1] - img.shape[1] + + if dataAugmentation: + r_off, c_off = random.randint(0, r_freedom), random.randint(0, c_freedom) + else: + r_off, c_off = r_freedom // 2, c_freedom // 2 + + target[r_off:img.shape[0]+r_off, c_off:img.shape[1]+c_off] = img # transpose for TF img = cv2.transpose(target) - # normalize - (m, s) = cv2.meanStdDev(img) - m = m[0][0] - s = s[0][0] - img = img - m - img = img / s if s > 0 else img + # convert to range [-1, 1] + img = img / 255 - 0.5 return img + + +if __name__ == '__main__': + import matplotlib.pyplot as plt + img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE) + img_aug = preprocess(img, (128, 32), True) + plt.subplot(121) + plt.imshow(img) + plt.subplot(122) + plt.imshow(cv2.transpose(img_aug)) + plt.show() \ No newline at end of file diff --git a/src/analyze.py b/src/analyze.py index d34a97b93cdf1944289fb94453f16ea27ef82e45..6658fd5f0bb99286d2e445d2c3e69e71c3667947 100644 --- a/src/analyze.py +++ b/src/analyze.py @@ -1,6 +1,3 @@ -from __future__ import division -from __future__ import print_function - import copy import math import pickle diff --git a/src/create_lmdb.py b/src/create_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..979571b103bbe2f27b2200409bb965ddf853d8b0 --- /dev/null +++ b/src/create_lmdb.py @@ -0,0 +1,27 @@ +import argparse +import pickle + +import cv2 +import lmdb +from path import Path + +parser = argparse.ArgumentParser() +parser.add_argument('--data_dir', type=Path, required=True) +args = parser.parse_args() + +# 2GB is enough for IAM dataset +assert not (args.data_dir / 'lmdb').exists() +env = lmdb.open(str(args.data_dir / 'lmdb'), map_size=1024 * 1024 * 1024 * 2) + +# go over all png files +fn_imgs = list((args.data_dir / 'img').walkfiles('*.png')) + +# and put the imgs into lmdb as pickled grayscale imgs +with env.begin(write=True) as txn: + for i, fn_img in enumerate(fn_imgs): + print(i, len(fn_imgs)) + img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE) + basename = fn_img.basename() + txn.put(basename.encode("ascii"), pickle.dumps(img)) + +env.close() diff --git a/src/main.py b/src/main.py index 50f54e06f717cbe68b53139da4dbfad652a65108..9e62aa2473aa027f1c5b474e76f44c57491b0e26 100644 --- a/src/main.py +++ b/src/main.py @@ -1,21 +1,18 @@ -from __future__ import division -from __future__ import print_function - import argparse import cv2 import editdistance -from DataLoader import DataLoader, Batch +from DataLoaderIAM import DataLoaderIAM, Batch from Model import Model, DecoderType from SamplePreprocessor import preprocess +from path import Path class FilePaths: "filenames and paths to data" fnCharList = '../model/charList.txt' fnAccuracy = '../model/accuracy.txt' - fnTrain = '../data/' fnInfer = '../data/test.png' fnCorpus = '../data/corpus.txt' @@ -25,7 +22,7 @@ def train(model, loader): epoch = 0 # number of training epochs since start bestCharErrorRate = float('inf') # best valdiation character error rate noImprovementSince = 0 # number of epochs no improvement of character error rate occured - earlyStopping = 5 # stop training after this number of epochs without improvement + earlyStopping = 25 # stop training after this number of epochs without improvement while True: epoch += 1 print('Epoch:', epoch) @@ -37,7 +34,7 @@ def train(model, loader): iterInfo = loader.getIteratorInfo() batch = loader.getNext() loss = model.trainBatch(batch) - print('Batch:', iterInfo[0], '/', iterInfo[1], 'Loss:', loss) + print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}') # validate charErrorRate = validate(model, loader) @@ -49,14 +46,14 @@ def train(model, loader): noImprovementSince = 0 model.save() open(FilePaths.fnAccuracy, 'w').write( - 'Validation character error rate of saved model: %f%%' % (charErrorRate * 100.0)) + f'Validation character error rate of saved model: {charErrorRate * 100.0}%') else: - print('Character error rate not improved') + print(f'Character error rate not improved, best so far: {charErrorRate * 100.0}%') noImprovementSince += 1 # stop training if no more improvement in the last x epochs if noImprovementSince >= earlyStopping: - print('No more improvement since %d epochs. Training stopped.' % earlyStopping) + print(f'No more improvement since {earlyStopping} epochs. Training stopped.') break @@ -70,7 +67,7 @@ def validate(model, loader): numWordTotal = 0 while loader.hasNext(): iterInfo = loader.getIteratorInfo() - print('Batch:', iterInfo[0], '/', iterInfo[1]) + print(f'Batch: {iterInfo[0]} / {iterInfo[1]}') batch = loader.getNext() (recognized, _) = model.inferBatch(batch) @@ -87,7 +84,7 @@ def validate(model, loader): # print validation result charErrorRate = numCharErr / numCharTotal wordAccuracy = numWordOK / numWordTotal - print('Character error rate: %f%%. Word accuracy: %f%%.' % (charErrorRate * 100.0, wordAccuracy * 100.0)) + print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.') return charErrorRate @@ -96,8 +93,8 @@ def infer(model, fnImg): img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize) batch = Batch(None, [img]) (recognized, probability) = model.inferBatch(batch, True) - print('Recognized:', '"' + recognized[0] + '"') - print('Probability:', probability[0]) + print(f'Recognized: "{recognized[0]}"') + print(f'Probability: {probability[0]}') def main(): @@ -110,6 +107,9 @@ def main(): parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding', action='store_true') parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true') + parser.add_argument('--fast', help='use lmdb to load images', action='store_true') + parser.add_argument('--data_dir', help='directory containing IAM dataset', type=Path, required=False) + parser.add_argument('--batch_size', help='batch size', type=int, default=100) args = parser.parse_args() @@ -122,7 +122,7 @@ def main(): # train or validate on IAM dataset if args.train or args.validate: # load training data, create TF model - loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen) + loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.imgSize, Model.maxTextLen, args.fast) # save characters of model for inference mode open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))