diff --git a/README.md b/README.md index 543404f6357f92073675669968e31234b162f524..516dd79998df64b6240416ec1e43d4bd9c37b39e 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ Follow these instructions to get the IAM dataset: ### Fast image loading Loading and decoding the png image files from the disk is the bottleneck even when using only a small GPU. The database LMDB is used to speed up image loading: -* Go to the `src` directory and run `createLMDB.py --data_dir path/to/IAM` with the IAM data directory specified +* Go to the `src` directory and run `create_lmdb.py --data_dir path/to/IAM` with the IAM data directory specified * A subfolder `lmdb` is created in the IAM data directory containing the LMDB files * When training the model, add the command line option `--fast` diff --git a/src/DataLoaderIAM.py b/src/DataLoaderIAM.py deleted file mode 100644 index a95eb0f5719085b92a91d02c83b273e3b25441af..0000000000000000000000000000000000000000 --- a/src/DataLoaderIAM.py +++ /dev/null @@ -1,155 +0,0 @@ -import pickle -import random - -import cv2 -import lmdb -import numpy as np -from path import Path - -from SamplePreprocessor import preprocess - - -class Sample: - "sample from the dataset" - - def __init__(self, gtText, filePath): - self.gtText = gtText - self.filePath = filePath - - -class Batch: - "batch containing images and ground truth texts" - - def __init__(self, gtTexts, imgs): - self.imgs = np.stack(imgs, axis=0) - self.gtTexts = gtTexts - - -class DataLoaderIAM: - "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database" - - def __init__(self, data_dir, batchSize, imgSize, maxTextLen, fast=True): - "loader for dataset at given location, preprocess images and text according to parameters" - - assert data_dir.exists() - - self.fast = fast - if fast: - self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True) - - self.dataAugmentation = False - self.currIdx = 0 - self.batchSize = batchSize - self.imgSize = imgSize - self.samples = [] - - f = open(data_dir / 'gt/words.txt') - chars = set() - bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05'] # known broken images in IAM dataset - for line in f: - # ignore comment line - if not line or line[0] == '#': - continue - - lineSplit = line.strip().split(' ') - assert len(lineSplit) >= 9 - - # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png - fileNameSplit = lineSplit[0].split('-') - fileName = data_dir / 'img' / fileNameSplit[0] / f'{fileNameSplit[0]}-{fileNameSplit[1]}' / lineSplit[0] + '.png' - - if lineSplit[0] in bad_samples_reference: - print('Ignoring known broken image:', fileName) - continue - - # GT text are columns starting at 9 - gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen) - chars = chars.union(set(list(gtText))) - - # put sample into list - self.samples.append(Sample(gtText, fileName)) - - # split into training and validation set: 95% - 5% - splitIdx = int(0.95 * len(self.samples)) - self.trainSamples = self.samples[:splitIdx] - self.validationSamples = self.samples[splitIdx:] - - # put words into lists - self.trainWords = [x.gtText for x in self.trainSamples] - self.validationWords = [x.gtText for x in self.validationSamples] - - # start with train set - self.trainSet() - - # list of all chars in dataset - self.charList = sorted(list(chars)) - - def truncateLabel(self, text, maxTextLen): - # ctc_loss can't compute loss if it cannot find a mapping between text label and input - # labels. 
Repeat letters cost double because of the blank symbol needing to be inserted. - # If a too-long label is provided, ctc_loss returns an infinite gradient - cost = 0 - for i in range(len(text)): - if i != 0 and text[i] == text[i - 1]: - cost += 2 - else: - cost += 1 - if cost > maxTextLen: - return text[:i] - return text - - def trainSet(self): - "switch to randomly chosen subset of training set" - self.dataAugmentation = True - self.currIdx = 0 - random.shuffle(self.trainSamples) - self.samples = self.trainSamples - self.currSet = 'train' - - def validationSet(self): - "switch to validation set" - self.dataAugmentation = False - self.currIdx = 0 - self.samples = self.validationSamples - self.currSet = 'val' - - def getIteratorInfo(self): - "current batch index and overall number of batches" - if self.currSet == 'train': - numBatches = int(np.floor(len(self.samples) / self.batchSize)) # train set: only full-sized batches - else: - numBatches = int(np.ceil(len(self.samples) / self.batchSize)) # val set: allow last batch to be smaller - currBatch = self.currIdx // self.batchSize + 1 - return currBatch, numBatches - - def hasNext(self): - "iterator" - if self.currSet == 'train': - return self.currIdx + self.batchSize <= len(self.samples) # train set: only full-sized batches - else: - return self.currIdx < len(self.samples) # val set: allow last batch to be smaller - - def getNext(self): - "iterator" - batchRange = range(self.currIdx, min(self.currIdx + self.batchSize, len(self.samples))) - gtTexts = [self.samples[i].gtText for i in batchRange] - - imgs = [] - for i in batchRange: - if self.fast: - with self.env.begin() as txn: - basename = Path(self.samples[i].filePath).basename() - data = txn.get(basename.encode("ascii")) - img = pickle.loads(data) - else: - img = cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE) - - imgs.append(preprocess(img, self.imgSize, self.dataAugmentation)) - - self.currIdx += self.batchSize - return Batch(gtTexts, imgs) - - -if __name__ == '__main__': - dl = DataLoaderIAM('../data/', 50, (128, 32), 32) - dl.getNext() diff --git a/src/Model.py b/src/Model.py deleted file mode 100644 index fcacde7f75f91dadb97e800b9c0f5cf52162d4ba..0000000000000000000000000000000000000000 --- a/src/Model.py +++ /dev/null @@ -1,295 +0,0 @@ -import os -import sys - -import numpy as np -import tensorflow as tf - -# Disable eager mode -tf.compat.v1.disable_eager_execution() - - -class DecoderType: - BestPath = 0 - BeamSearch = 1 - WordBeamSearch = 2 - - -class Model: - "minimalistic TF model for HTR" - - # model constants - imgSize = (128, 32) - maxTextLen = 32 - - def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False, dump=False): - "init model: add CNN, RNN and CTC and initialize TF" - self.dump = dump - self.charList = charList - self.decoderType = decoderType - self.mustRestore = mustRestore - self.snapID = 0 - - # Whether to use normalization over a batch or a population - self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train') - - # input image batch - self.inputImgs = tf.compat.v1.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1])) - - # setup CNN, RNN and CTC - self.setupCNN() - self.setupRNN() - self.setupCTC() - - # setup optimizer to train NN - self.batchesTrained = 0 - self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) - with tf.control_dependencies(self.update_ops): - self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss) - - # initialize TF - (self.sess, 
self.saver) = self.setupTF() - - def setupCNN(self): - "create CNN layers and return output of these layers" - cnnIn4d = tf.expand_dims(input=self.inputImgs, axis=3) - - # list of parameters for the layers - kernelVals = [5, 5, 3, 3, 3] - featureVals = [1, 32, 64, 128, 128, 256] - strideVals = poolVals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)] - numLayers = len(strideVals) - - # create layers - pool = cnnIn4d # input to first CNN layer - for i in range(numLayers): - kernel = tf.Variable( - tf.random.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], - stddev=0.1)) - conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1)) - conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train) - relu = tf.nn.relu(conv_norm) - pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1), - strides=(1, strideVals[i][0], strideVals[i][1], 1), padding='VALID') - - self.cnnOut4d = pool - - def setupRNN(self): - "create RNN layers and return output of these layers" - rnnIn3d = tf.squeeze(self.cnnOut4d, axis=[2]) - - # basic cells which is used to build RNN - numHidden = 256 - cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in - range(2)] # 2 layers - - # stack basic cells - stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True) - - # bidirectional RNN - # BxTxF -> BxTx2H - ((fw, bw), _) = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, - dtype=rnnIn3d.dtype) - - # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H - concat = tf.expand_dims(tf.concat([fw, bw], 2), 2) - - # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC - kernel = tf.Variable(tf.random.truncated_normal([1, 1, numHidden * 2, len(self.charList) + 1], stddev=0.1)) - self.rnnOut3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2]) - - def setupCTC(self): - "create CTC loss and decoder and return them" - # BxTxC -> TxBxC - self.ctcIn3dTBC = tf.transpose(a=self.rnnOut3d, perm=[1, 0, 2]) - # ground truth text as sparse tensor - self.gtTexts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]), - tf.compat.v1.placeholder(tf.int32, [None]), - tf.compat.v1.placeholder(tf.int64, [2])) - - # calc loss for batch - self.seqLen = tf.compat.v1.placeholder(tf.int32, [None]) - self.loss = tf.reduce_mean(input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC, - sequence_length=self.seqLen, - ctc_merge_repeated=True)) - - # calc loss for each element to compute label probability - self.savedCtcInput = tf.compat.v1.placeholder(tf.float32, - shape=[Model.maxTextLen, None, len(self.charList) + 1]) - self.lossPerElement = tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput, - sequence_length=self.seqLen, ctc_merge_repeated=True) - - # best path decoding or beam search decoding - if self.decoderType == DecoderType.BestPath: - self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen) - elif self.decoderType == DecoderType.BeamSearch: - self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, - beam_width=50) - # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch) - elif self.decoderType == DecoderType.WordBeamSearch: - # prepare information about language (dictionary, characters in dataset, characters forming words) - chars = 
str().join(self.charList) - wordChars = open('../model/wordCharList.txt').read().splitlines()[0] - corpus = open('../data/corpus.txt').read() - - # decode using the "Words" mode of word beam search - from word_beam_search import WordBeamSearch - self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), - wordChars.encode('utf8')) - - # the input to the decoder must have softmax already applied - self.wbsInput = tf.nn.softmax(self.ctcIn3dTBC, axis=2) - - def setupTF(self): - "initialize TF" - print('Python: ' + sys.version) - print('Tensorflow: ' + tf.__version__) - - sess = tf.compat.v1.Session() # TF session - - saver = tf.compat.v1.train.Saver(max_to_keep=1) # saver saves model to file - modelDir = '../model/' - latestSnapshot = tf.train.latest_checkpoint(modelDir) # is there a saved model? - - # if model must be restored (for inference), there must be a snapshot - if self.mustRestore and not latestSnapshot: - raise Exception('No saved model found in: ' + modelDir) - - # load saved model if available - if latestSnapshot: - print('Init with stored values from ' + latestSnapshot) - saver.restore(sess, latestSnapshot) - else: - print('Init with new values') - sess.run(tf.compat.v1.global_variables_initializer()) - - return (sess, saver) - - def toSparse(self, texts): - "put ground truth texts into sparse tensor for ctc_loss" - indices = [] - values = [] - shape = [len(texts), 0] # last entry must be max(labelList[i]) - - # go over all texts - for (batchElement, text) in enumerate(texts): - # convert to string of label (i.e. class-ids) - labelStr = [self.charList.index(c) for c in text] - # sparse tensor must have size of max. label-string - if len(labelStr) > shape[1]: - shape[1] = len(labelStr) - # put each label into sparse tensor - for (i, label) in enumerate(labelStr): - indices.append([batchElement, i]) - values.append(label) - - return (indices, values, shape) - - def decoderOutputToText(self, ctcOutput, batchSize): - "extract texts from output of CTC decoder" - - # word beam search: already contains label strings - if self.decoderType == DecoderType.WordBeamSearch: - labelStrs = ctcOutput - - # TF decoders: label strings are contained in sparse tensor - else: - # ctc returns tuple, first element is SparseTensor - decoded = ctcOutput[0][0] - - # contains string of labels for each batch element - labelStrs = [[] for _ in range(batchSize)] - - # go over all indices and save mapping: batch -> values - for (idx, idx2d) in enumerate(decoded.indices): - label = decoded.values[idx] - batchElement = idx2d[0] # index according to [b,t] - labelStrs[batchElement].append(label) - - # map labels to chars for all batch elements - return [str().join([self.charList[c] for c in labelStr]) for labelStr in labelStrs] - - def trainBatch(self, batch): - "feed a batch into the NN to train it" - numBatchElements = len(batch.imgs) - sparse = self.toSparse(batch.gtTexts) - evalList = [self.optimizer, self.loss] - feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse, - self.seqLen: [Model.maxTextLen] * numBatchElements, self.is_train: True} - _, lossVal = self.sess.run(evalList, feedDict) - self.batchesTrained += 1 - return lossVal - - def dumpNNOutput(self, rnnOutput): - "dump the output of the NN to CSV file(s)" - dumpDir = '../dump/' - if not os.path.isdir(dumpDir): - os.mkdir(dumpDir) - - # iterate over all batch elements and create a CSV file for each one - maxT, maxB, maxC = rnnOutput.shape - for b in range(maxB): - csv = '' - for t in range(maxT): - for c in 
range(maxC): - csv += str(rnnOutput[t, b, c]) + ';' - csv += '\n' - fn = dumpDir + 'rnnOutput_' + str(b) + '.csv' - print('Write dump of NN to file: ' + fn) - with open(fn, 'w') as f: - f.write(csv) - - def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False): - "feed a batch into the NN to recognize the texts" - - # decode, optionally save RNN output - numBatchElements = len(batch.imgs) - - # put tensors to be evaluated into list - evalList = [] - - if self.decoderType == DecoderType.WordBeamSearch: - evalList.append(self.wbsInput) - else: - evalList.append(self.decoder) - - if self.dump or calcProbability: - evalList.append(self.ctcIn3dTBC) - - # dict containing all tensor fed into the model - feedDict = {self.inputImgs: batch.imgs, self.seqLen: [Model.maxTextLen] * numBatchElements, - self.is_train: False} - - # evaluate model - evalRes = self.sess.run(evalList, feedDict) - - # TF decoders: decoding already done in TF graph - if self.decoderType != DecoderType.WordBeamSearch: - decoded = evalRes[0] - # word beam search decoder: decoding is done in C++ function compute() - else: - decoded = self.decoder.compute(evalRes[0]) - - # map labels (numbers) to character string - texts = self.decoderOutputToText(decoded, numBatchElements) - - # feed RNN output and recognized text into CTC loss to compute labeling probability - probs = None - if calcProbability: - sparse = self.toSparse(batch.gtTexts) if probabilityOfGT else self.toSparse(texts) - ctcInput = evalRes[1] - evalList = self.lossPerElement - feedDict = {self.savedCtcInput: ctcInput, self.gtTexts: sparse, - self.seqLen: [Model.maxTextLen] * numBatchElements, self.is_train: False} - lossVals = self.sess.run(evalList, feedDict) - probs = np.exp(-lossVals) - - # dump the output of the NN to CSV file(s) - if self.dump: - self.dumpNNOutput(evalRes[1]) - - return texts, probs - - def save(self): - "save model to file" - self.snapID += 1 - self.saver.save(self.sess, '../model/snapshot', global_step=self.snapID) diff --git a/src/createLMDB.py b/src/create_lmdb.py similarity index 100% rename from src/createLMDB.py rename to src/create_lmdb.py diff --git a/src/dataloader_iam.py b/src/dataloader_iam.py new file mode 100644 index 0000000000000000000000000000000000000000..654708a1dbc382349cbd478265133cc9e127ca45 --- /dev/null +++ b/src/dataloader_iam.py @@ -0,0 +1,142 @@ +import pickle +import random +from collections import namedtuple + +import cv2 +import lmdb +import numpy as np +from path import Path + +from preprocess import preprocess + +Sample = namedtuple('Sample', 'gt_text, file_path') +Batch = namedtuple('Batch', 'gt_texts, imgs') + + +class DataLoaderIAM: + "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database" + + def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True): + """Loader for dataset.""" + + assert data_dir.exists() + + self.fast = fast + if fast: + self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True) + + self.data_augmentation = False + self.curr_idx = 0 + self.batch_size = batch_size + self.img_size = img_size + self.samples = [] + + f = open(data_dir / 'gt/words.txt') + chars = set() + bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05'] # known broken images in IAM dataset + for line in f: + # ignore comment line + if not line or line[0] == '#': + continue + + line_split = line.strip().split(' ') + assert len(line_split) >= 9 + + # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png + 
file_name_split = line_split[0].split('-') + file_name = data_dir / 'img' / file_name_split[0] / f'{file_name_split[0]}-{file_name_split[1]}' / \ + line_split[0] + '.png' + + if line_split[0] in bad_samples_reference: + print('Ignoring known broken image:', file_name) + continue + + # GT text are columns starting at 9 + gt_text = self.truncate_label(' '.join(line_split[8:]), max_text_len) + chars = chars.union(set(list(gt_text))) + + # put sample into list + self.samples.append(Sample(gt_text, file_name)) + + # split into training and validation set: 95% - 5% + split_idx = int(0.95 * len(self.samples)) + self.train_samples = self.samples[:split_idx] + self.validation_samples = self.samples[split_idx:] + + # put words into lists + self.train_words = [x.gt_text for x in self.train_samples] + self.validation_words = [x.gt_text for x in self.validation_samples] + + # start with train set + self.train_set() + + # list of all chars in dataset + self.char_list = sorted(list(chars)) + + @staticmethod + def truncate_label(text, max_text_len): + """ + Function ctc_loss can't compute loss if it cannot find a mapping between text label and input + labels. Repeat letters cost double because of the blank symbol needing to be inserted. + If a too-long label is provided, ctc_loss returns an infinite gradient. + """ + cost = 0 + for i in range(len(text)): + if i != 0 and text[i] == text[i - 1]: + cost += 2 + else: + cost += 1 + if cost > max_text_len: + return text[:i] + return text + + def train_set(self): + """Switch to randomly chosen subset of training set.""" + self.data_augmentation = True + self.curr_idx = 0 + random.shuffle(self.train_samples) + self.samples = self.train_samples + self.curr_set = 'train' + + def validation_set(self): + "switch to validation set" + self.data_augmentation = False + self.curr_idx = 0 + self.samples = self.validation_samples + self.curr_set = 'val' + + def get_iterator_info(self): + """Current batch index and overall number of batches.""" + if self.curr_set == 'train': + num_batches = int(np.floor(len(self.samples) / self.batch_size)) # train set: only full-sized batches + else: + num_batches = int(np.ceil(len(self.samples) / self.batch_size)) # val set: allow last batch to be smaller + curr_batch = self.curr_idx // self.batch_size + 1 + return curr_batch, num_batches + + def has_next(self): + "iterator" + if self.curr_set == 'train': + return self.curr_idx + self.batch_size <= len(self.samples) # train set: only full-sized batches + else: + return self.curr_idx < len(self.samples) # val set: allow last batch to be smaller + + def get_next(self): + "iterator" + batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples))) + gt_texts = [self.samples[i].gt_text for i in batch_range] + + imgs = [] + for i in batch_range: + if self.fast: + with self.env.begin() as txn: + basename = Path(self.samples[i].file_path).basename() + data = txn.get(basename.encode("ascii")) + img = pickle.loads(data) + else: + img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE) + + imgs.append(preprocess(img, self.img_size, self.data_augmentation)) + + self.curr_idx += self.batch_size + return Batch(gt_texts, imgs) diff --git a/src/main.py b/src/main.py index c6fc2229d7059007f2106fcdde4bdced20a7cd18..492cb80eebc59c95f1d0ab6d437ce1a119032fdc 100644 --- a/src/main.py +++ b/src/main.py @@ -5,111 +5,111 @@ import cv2 import editdistance from path import Path -from DataLoaderIAM import DataLoaderIAM, Batch -from Model import Model, DecoderType -from 
SamplePreprocessor import preprocess +from dataloader_iam import DataLoaderIAM, Batch +from model import Model, DecoderType +from preprocess import preprocess class FilePaths: - "filenames and paths to data" - fnCharList = '../model/charList.txt' - fnSummary = '../model/summary.json' - fnInfer = '../data/test.png' - fnCorpus = '../data/corpus.txt' + """Filenames and paths to data.""" + fn_char_list = '../model/charList.txt' + fn_summary = '../model/summary.json' + fn_infer = '../data/test.png' + fn_corpus = '../data/corpus.txt' -def write_summary(charErrorRates, wordAccuracies): - with open(FilePaths.fnSummary, 'w') as f: - json.dump({'charErrorRates': charErrorRates, 'wordAccuracies': wordAccuracies}, f) +def write_summary(char_error_rates, word_accuracies): + with open(FilePaths.fn_summary, 'w') as f: + json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f) def train(model, loader): - "train NN" + """Trains NN.""" epoch = 0 # number of training epochs since start - summaryCharErrorRates = [] - summaryWordAccuracies = [] - bestCharErrorRate = float('inf') # best valdiation character error rate - noImprovementSince = 0 # number of epochs no improvement of character error rate occured - earlyStopping = 25 # stop training after this number of epochs without improvement + summary_char_error_rates = [] + summary_word_accuracies = [] + best_char_error_rate = float('inf') # best valdiation character error rate + no_improvement_since = 0 # number of epochs no improvement of character error rate occured + early_stopping = 25 # stop training after this number of epochs without improvement while True: epoch += 1 print('Epoch:', epoch) # train print('Train NN') - loader.trainSet() - while loader.hasNext(): - iterInfo = loader.getIteratorInfo() - batch = loader.getNext() + loader.train_set() + while loader.has_next(): + iter_info = loader.get_iterator_info() + batch = loader.get_next() loss = model.trainBatch(batch) - print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}') + print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}') # validate - charErrorRate, wordAccuracy = validate(model, loader) + char_error_rate, word_accuracy = validate(model, loader) # write summary - summaryCharErrorRates.append(charErrorRate) - summaryWordAccuracies.append(wordAccuracy) - write_summary(summaryCharErrorRates, summaryWordAccuracies) + summary_char_error_rates.append(char_error_rate) + summary_word_accuracies.append(word_accuracy) + write_summary(summary_char_error_rates, summary_word_accuracies) # if best validation accuracy so far, save model parameters - if charErrorRate < bestCharErrorRate: + if char_error_rate < best_char_error_rate: print('Character error rate improved, save model') - bestCharErrorRate = charErrorRate - noImprovementSince = 0 + best_char_error_rate = char_error_rate + no_improvement_since = 0 model.save() else: - print(f'Character error rate not improved, best so far: {charErrorRate * 100.0}%') - noImprovementSince += 1 + print(f'Character error rate not improved, best so far: {char_error_rate * 100.0}%') + no_improvement_since += 1 # stop training if no more improvement in the last x epochs - if noImprovementSince >= earlyStopping: - print(f'No more improvement since {earlyStopping} epochs. Training stopped.') + if no_improvement_since >= early_stopping: + print(f'No more improvement since {early_stopping} epochs. 
Training stopped.') break def validate(model, loader): - "validate NN" + """Validates NN.""" print('Validate NN') - loader.validationSet() - numCharErr = 0 - numCharTotal = 0 - numWordOK = 0 - numWordTotal = 0 - while loader.hasNext(): - iterInfo = loader.getIteratorInfo() - print(f'Batch: {iterInfo[0]} / {iterInfo[1]}') - batch = loader.getNext() - (recognized, _) = model.inferBatch(batch) + loader.validation_set() + num_char_err = 0 + num_char_total = 0 + num_word_ok = 0 + num_word_total = 0 + while loader.has_next(): + iter_info = loader.get_iterator_info() + print(f'Batch: {iter_info[0]} / {iter_info[1]}') + batch = loader.get_next() + (recognized, _) = model.infer_batch(batch) print('Ground truth -> Recognized') for i in range(len(recognized)): - numWordOK += 1 if batch.gtTexts[i] == recognized[i] else 0 - numWordTotal += 1 - dist = editdistance.eval(recognized[i], batch.gtTexts[i]) - numCharErr += dist - numCharTotal += len(batch.gtTexts[i]) - print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + batch.gtTexts[i] + '"', '->', + num_word_ok += 1 if batch.gt_texts[i] == recognized[i] else 0 + num_word_total += 1 + dist = editdistance.eval(recognized[i], batch.gt_texts[i]) + num_char_err += dist + num_char_total += len(batch.gt_texts[i]) + print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + batch.gt_texts[i] + '"', '->', '"' + recognized[i] + '"') # print validation result - charErrorRate = numCharErr / numCharTotal - wordAccuracy = numWordOK / numWordTotal - print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.') - return charErrorRate, wordAccuracy + char_error_rate = num_char_err / num_char_total + word_accuracy = num_word_ok / num_word_total + print(f'Character error rate: {char_error_rate * 100.0}%. 
Word accuracy: {word_accuracy * 100.0}%.') + return char_error_rate, word_accuracy -def infer(model, fnImg): - "recognize text in image provided by file path" - img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize) +def infer(model, fn_img): + """Recognizes text in image provided by file path.""" + img = preprocess(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE), Model.img_size) batch = Batch(None, [img]) - (recognized, probability) = model.inferBatch(batch, True) + (recognized, probability) = model.infer_batch(batch, True) print(f'Recognized: "{recognized[0]}"') print(f'Probability: {probability[0]}') def main(): - "main function" + """Main function.""" parser = argparse.ArgumentParser() parser.add_argument('--train', help='train the NN', action='store_true') parser.add_argument('--validate', help='validate the NN', action='store_true') @@ -122,36 +122,34 @@ def main(): args = parser.parse_args() # set chosen CTC decoder - if args.decoder == 'bestpath': - decoderType = DecoderType.BestPath - elif args.decoder == 'beamsearch': - decoderType = DecoderType.BeamSearch - elif args.decoder == 'wordbeamsearch': - decoderType = DecoderType.WordBeamSearch + decoder_mapping = {'bestpath': DecoderType.BestPath, + 'beamsearch': DecoderType.BeamSearch, + 'wordbeamsearch': DecoderType.WordBeamSearch} + decoder_type = decoder_mapping[args.decoder] # train or validate on IAM dataset if args.train or args.validate: # load training data, create TF model - loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.imgSize, Model.maxTextLen, args.fast) + loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast) # save characters of model for inference mode - open(FilePaths.fnCharList, 'w').write(str().join(loader.charList)) + open(FilePaths.fn_char_list, 'w').write(str().join(loader.char_list)) # save words contained in dataset into file - open(FilePaths.fnCorpus, 'w').write(str(' ').join(loader.trainWords + loader.validationWords)) + open(FilePaths.fn_corpus, 'w').write(str(' ').join(loader.train_words + loader.validation_words)) # execute training or validation if args.train: - model = Model(loader.charList, decoderType) + model = Model(loader.char_list, decoder_type) train(model, loader) elif args.validate: - model = Model(loader.charList, decoderType, mustRestore=True) + model = Model(loader.char_list, decoder_type, must_restore=True) validate(model, loader) # infer text on test image else: - model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump) - infer(model, FilePaths.fnInfer) + model = Model(open(FilePaths.fn_char_list).read(), decoder_type, must_restore=True, dump=args.dump) + infer(model, FilePaths.fn_infer) if __name__ == '__main__': diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..63652c7a74ded766e9d0e1ce30fa6642117edc77 --- /dev/null +++ b/src/model.py @@ -0,0 +1,299 @@ +import os +import sys + +import numpy as np +import tensorflow as tf + +# Disable eager mode +tf.compat.v1.disable_eager_execution() + + +class DecoderType: + """CTC decoder types.""" + BestPath = 0 + BeamSearch = 1 + WordBeamSearch = 2 + + +class Model: + """Minimalistic TF model for HTR.""" + + # model constants + img_size = (128, 32) + max_text_len = 32 + + def __init__(self, char_list, decoder_type=DecoderType.BestPath, must_restore=False, dump=False): + """Init model: add CNN, RNN and CTC and initialize TF.""" + self.dump = dump + self.char_list = char_list + 
self.decoder_type = decoder_type + self.must_restore = must_restore + self.snap_ID = 0 + + # Whether to use normalization over a batch or a population + self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train') + + # input image batch + self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, Model.img_size[0], Model.img_size[1])) + + # setup CNN, RNN and CTC + self.setup_cnn() + self.setup_rnn() + self.setup_ctc() + + # setup optimizer to train NN + self.batches_trained = 0 + self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(self.update_ops): + self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss) + + # initialize TF + self.sess, self.saver = self.setup_tf() + + def setup_cnn(self): + """Create CNN layers.""" + cnn_in4d = tf.expand_dims(input=self.input_imgs, axis=3) + + # list of parameters for the layers + kernel_vals = [5, 5, 3, 3, 3] + feature_vals = [1, 32, 64, 128, 128, 256] + stride_vals = poolVals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)] + num_layers = len(stride_vals) + + # create layers + pool = cnn_in4d # input to first CNN layer + for i in range(num_layers): + kernel = tf.Variable( + tf.random.truncated_normal([kernel_vals[i], kernel_vals[i], feature_vals[i], feature_vals[i + 1]], + stddev=0.1)) + conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1)) + conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train) + relu = tf.nn.relu(conv_norm) + pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1), + strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID') + + self.cnn_out_4d = pool + + def setup_rnn(self): + """Create RNN layers.""" + rnn_in3d = tf.squeeze(self.cnn_out_4d, axis=[2]) + + # basic cells which is used to build RNN + num_hidden = 256 + cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=num_hidden, state_is_tuple=True) for _ in + range(2)] # 2 layers + + # stack basic cells + stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True) + + # bidirectional RNN + # BxTxF -> BxTx2H + (fw, bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_in3d, + dtype=rnn_in3d.dtype) + + # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H + concat = tf.expand_dims(tf.concat([fw, bw], 2), 2) + + # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC + kernel = tf.Variable(tf.random.truncated_normal([1, 1, num_hidden * 2, len(self.char_list) + 1], stddev=0.1)) + self.rnn_out_3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), + axis=[2]) + + def setup_ctc(self): + """Create CTC loss and decoder.""" + # BxTxC -> TxBxC + self.ctc_in_3d_tbc = tf.transpose(a=self.rnn_out_3d, perm=[1, 0, 2]) + # ground truth text as sparse tensor + self.gt_texts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]), + tf.compat.v1.placeholder(tf.int32, [None]), + tf.compat.v1.placeholder(tf.int64, [2])) + + # calc loss for batch + self.seq_len = tf.compat.v1.placeholder(tf.int32, [None]) + self.loss = tf.reduce_mean( + input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.ctc_in_3d_tbc, + sequence_length=self.seq_len, + ctc_merge_repeated=True)) + + # calc loss for each element to compute label probability + self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32, + shape=[Model.max_text_len, None, len(self.char_list) + 1]) + self.loss_per_element = 
tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input, + sequence_length=self.seq_len, ctc_merge_repeated=True) + + # best path decoding or beam search decoding + if self.decoder_type == DecoderType.BestPath: + self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len) + elif self.decoder_type == DecoderType.BeamSearch: + self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len, + beam_width=50) + # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch) + elif self.decoder_type == DecoderType.WordBeamSearch: + # prepare information about language (dictionary, characters in dataset, characters forming words) + chars = str().join(self.char_list) + word_chars = open('../model/wordCharList.txt').read().splitlines()[0] + corpus = open('../data/corpus.txt').read() + + # decode using the "Words" mode of word beam search + from word_beam_search import WordBeamSearch + self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), + word_chars.encode('utf8')) + + # the input to the decoder must have softmax already applied + self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2) + + def setup_tf(self): + """Initialize TF.""" + print('Python: ' + sys.version) + print('Tensorflow: ' + tf.__version__) + + sess = tf.compat.v1.Session() # TF session + + saver = tf.compat.v1.train.Saver(max_to_keep=1) # saver saves model to file + model_dir = '../model/' + latest_snapshot = tf.train.latest_checkpoint(model_dir) # is there a saved model? + + # if model must be restored (for inference), there must be a snapshot + if self.must_restore and not latest_snapshot: + raise Exception('No saved model found in: ' + model_dir) + + # load saved model if available + if latest_snapshot: + print('Init with stored values from ' + latest_snapshot) + saver.restore(sess, latest_snapshot) + else: + print('Init with new values') + sess.run(tf.compat.v1.global_variables_initializer()) + + return sess, saver + + def to_sparse(self, texts): + """Put ground truth texts into sparse tensor for ctc_loss.""" + indices = [] + values = [] + shape = [len(texts), 0] # last entry must be max(labelList[i]) + + # go over all texts + for (batchElement, text) in enumerate(texts): + # convert to string of label (i.e. class-ids) + label_str = [self.char_list.index(c) for c in text] + # sparse tensor must have size of max. 
label-string + if len(label_str) > shape[1]: + shape[1] = len(label_str) + # put each label into sparse tensor + for i, label in enumerate(label_str): + indices.append([batchElement, i]) + values.append(label) + + return indices, values, shape + + def decoder_output_to_text(self, ctc_output, batch_size): + """Extract texts from output of CTC decoder.""" + + # word beam search: already contains label strings + if self.decoder_type == DecoderType.WordBeamSearch: + label_strs = ctc_output + + # TF decoders: label strings are contained in sparse tensor + else: + # ctc returns tuple, first element is SparseTensor + decoded = ctc_output[0][0] + + # contains string of labels for each batch element + label_strs = [[] for _ in range(batch_size)] + + # go over all indices and save mapping: batch -> values + for (idx, idx2d) in enumerate(decoded.indices): + label = decoded.values[idx] + batch_element = idx2d[0] # index according to [b,t] + label_strs[batch_element].append(label) + + # map labels to chars for all batch elements + return [str().join([self.char_list[c] for c in labelStr]) for labelStr in label_strs] + + def trainBatch(self, batch): + """Feed a batch into the NN to train it.""" + num_batch_elements = len(batch.imgs) + sparse = self.to_sparse(batch.gt_texts) + eval_list = [self.optimizer, self.loss] + feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse, + self.seq_len: [Model.max_text_len] * num_batch_elements, self.is_train: True} + _, loss_val = self.sess.run(eval_list, feed_dict) + self.batches_trained += 1 + return loss_val + + @staticmethod + def dump_nn_output(rnn_output): + """Dump the output of the NN to CSV file(s).""" + dump_dir = '../dump/' + if not os.path.isdir(dump_dir): + os.mkdir(dump_dir) + + # iterate over all batch elements and create a CSV file for each one + max_t, max_b, max_c = rnn_output.shape + for b in range(max_b): + csv = '' + for t in range(max_t): + for c in range(max_c): + csv += str(rnn_output[t, b, c]) + ';' + csv += '\n' + fn = dump_dir + 'rnnOutput_' + str(b) + '.csv' + print('Write dump of NN to file: ' + fn) + with open(fn, 'w') as f: + f.write(csv) + + def infer_batch(self, batch, calc_probability=False, probability_of_gt=False): + """Feed a batch into the NN to recognize the texts.""" + + # decode, optionally save RNN output + num_batch_elements = len(batch.imgs) + + # put tensors to be evaluated into list + eval_list = [] + + if self.decoder_type == DecoderType.WordBeamSearch: + eval_list.append(self.wbs_input) + else: + eval_list.append(self.decoder) + + if self.dump or calc_probability: + eval_list.append(self.ctc_in_3d_tbc) + + # dict containing all tensor fed into the model + feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [Model.max_text_len] * num_batch_elements, + self.is_train: False} + + # evaluate model + eval_res = self.sess.run(eval_list, feed_dict) + + # TF decoders: decoding already done in TF graph + if self.decoder_type != DecoderType.WordBeamSearch: + decoded = eval_res[0] + # word beam search decoder: decoding is done in C++ function compute() + else: + decoded = self.decoder.compute(eval_res[0]) + + # map labels (numbers) to character string + texts = self.decoder_output_to_text(decoded, num_batch_elements) + + # feed RNN output and recognized text into CTC loss to compute labeling probability + probs = None + if calc_probability: + sparse = self.to_sparse(batch.gt_texts) if probability_of_gt else self.to_sparse(texts) + ctc_input = eval_res[1] + eval_list = self.loss_per_element + feed_dict = 
{self.saved_ctc_input: ctc_input, self.gt_texts: sparse, + self.seq_len: [Model.max_text_len] * num_batch_elements, self.is_train: False} + loss_vals = self.sess.run(eval_list, feed_dict) + probs = np.exp(-loss_vals) + + # dump the output of the NN to CSV file(s) + if self.dump: + self.dump_nn_output(eval_res[1]) + + return texts, probs + + def save(self): + """Save model to file.""" + self.snap_ID += 1 + self.saver.save(self.sess, '../model/snapshot', global_step=self.snap_ID) diff --git a/src/SamplePreprocessor.py b/src/preprocess.py similarity index 82% rename from src/SamplePreprocessor.py rename to src/preprocess.py index 590b735d47433d56500a9e7063b2ec88559cdfe7..272713682bb74743e06eee5907bedc8070314a60 100644 --- a/src/SamplePreprocessor.py +++ b/src/preprocess.py @@ -4,16 +4,16 @@ import cv2 import numpy as np -def preprocess(img, imgSize, dataAugmentation=False): +def preprocess(img, img_size, data_augmentation=False): "put img into target img of size imgSize, transpose for TF and normalize gray-values" # there are damaged files in IAM dataset - just use black image instead if img is None: - img = np.zeros(imgSize[::-1]) + img = np.zeros(img_size[::-1]) # data augmentation img = img.astype(np.float) - if dataAugmentation: + if data_augmentation: # photometric data augmentation if random.random() < 0.25: rand_odd = lambda: random.randint(1, 3) * 2 + 1 @@ -30,7 +30,7 @@ def preprocess(img, imgSize, dataAugmentation=False): img = 255 - img # geometric data augmentation - wt, ht = imgSize + wt, ht = img_size h, w = img.shape f = min(wt / w, ht / h) fx = f * np.random.uniform(0.75, 1.25) @@ -46,13 +46,13 @@ def preprocess(img, imgSize, dataAugmentation=False): # map image into target image M = np.float32([[fx, 0, tx], [0, fy, ty]]) - target = np.ones(imgSize[::-1]) * 255 / 2 - img = cv2.warpAffine(img, M, dsize=imgSize, dst=target, borderMode=cv2.BORDER_TRANSPARENT) + target = np.ones(img_size[::-1]) * 255 / 2 + img = cv2.warpAffine(img, M, dsize=img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT) # no data augmentation else: # center image - wt, ht = imgSize + wt, ht = img_size h, w = img.shape f = min(wt / w, ht / h) tx = (wt - w * f) / 2 @@ -60,8 +60,8 @@ def preprocess(img, imgSize, dataAugmentation=False): # map image into target image M = np.float32([[f, 0, tx], [0, f, ty]]) - target = np.ones(imgSize[::-1]) * 255 / 2 - img = cv2.warpAffine(img, M, dsize=imgSize, dst=target, borderMode=cv2.BORDER_TRANSPARENT) + target = np.ones(img_size[::-1]) * 255 / 2 + img = cv2.warpAffine(img, M, dsize=img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT) # transpose for TF img = cv2.transpose(img)
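
Note (not part of the patch above): a minimal usage sketch of the renamed modules, mirroring the infer() path in src/main.py from this diff. It assumes a trained snapshot already exists under ../model/ and that ../model/charList.txt and ../data/test.png are present, as referenced by FilePaths in the patch; otherwise all names come from the patched code.

import cv2

from dataloader_iam import Batch
from model import Model, DecoderType
from preprocess import preprocess

# characters seen during training; main.py writes this file (FilePaths.fn_char_list)
char_list = open('../model/charList.txt').read()

# must_restore=True loads the latest snapshot from ../model/ (see Model.setup_tf)
model = Model(char_list, DecoderType.BestPath, must_restore=True)

# read a grayscale image and scale/center it to the fixed model input size (128x32)
img = preprocess(cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE), Model.img_size)

# a Batch holds ground-truth texts (None for inference) and the list of images
recognized, probability = model.infer_batch(Batch(None, [img]), True)
print(f'Recognized: "{recognized[0]}" with probability {probability[0]}')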