Skip to content
Snippets Groups Projects
Commit 67859b2c authored by Harald Scheidl's avatar Harald Scheidl
Browse files

use random subset of training set for faster epochs

parent df7cb0a8
No related branches found
No related tags found
No related merge requests found
# Handwritten Text Recognition with TensorFlow # Handwritten Text Recognition with TensorFlow
Handwritten Text Recognition (HTR) system implemented with TensorFlow (TF) and trained on the IAM off-line HTR dataset. Handwritten Text Recognition (HTR) system implemented with TensorFlow (TF) and trained on the IAM off-line HTR dataset.
This Neural Network (NN) model is trained to recognize segmented words as shown in the illustration below. This Neural Network (NN) model recognizes the text contained in the images of segmented words as shown in the illustration below.
As the images of segmented words are smaller than images of complete text-lines, the NN can be kept small and therefore training on the CPU is feasible. As these word-images are smaller than images of complete text-lines, the NN can be kept small and training on the CPU is feasible.
I will give some hints in how to extend the model in case you need larger input-images or want a better recognition accuracy. I will give some hints how to extend the model in case you need larger input-images or want better recognition accuracy.
![img](./doc/htr.png) ![img](./doc/htr.png)
...@@ -47,17 +47,20 @@ The expected output is shown below. ...@@ -47,17 +47,20 @@ The expected output is shown below.
``` ```
> python main.py train > python main.py train
Init with stored values from ../model/snapshot-1 Init with new values
Epoch: 0 Epoch: 1
Train NN Train NN
Batch: 0 / 2191 Loss: 3.87954 Batch: 1 / 500 Loss: 113.333
Batch: 1 / 2191 Loss: 5.31012 Batch: 2 / 500 Loss: 40.0665
Batch: 2 / 2191 Loss: 3.87662 Batch: 3 / 500 Loss: 24.2433
Batch: 3 / 2191 Loss: 4.03646 Batch: 4 / 500 Loss: 21.644
Batch: 5 / 500 Loss: 22.2018
Batch: 6 / 500 Loss: 18.6628
Batch: 7 / 500 Loss: 20.9978
... ...
Validate NN Validate NN
Batch: 0 / 115 Batch: 1 / 115
Ground truth -> Recognized Ground truth -> Recognized
[OK] "," -> "," [OK] "," -> ","
[ERR] "Di" -> "D" [ERR] "Di" -> "D"
...@@ -67,7 +70,7 @@ Ground truth -> Recognized ...@@ -67,7 +70,7 @@ Ground truth -> Recognized
[OK] "told" -> "told" [OK] "told" -> "told"
[OK] "her" -> "her" [OK] "her" -> "her"
... ...
Correctly recognized words: 71.0 % Correctly recognized words: 67.0608695652174 %
``` ```
### Other datasets ### Other datasets
...@@ -79,11 +82,11 @@ Either you convert your dataset to the IAM format (look at `words.txt` and the c ...@@ -79,11 +82,11 @@ Either you convert your dataset to the IAM format (look at `words.txt` and the c
### Overview ### Overview
The model is a stripped-down version of the HTR system I used for my thesis. The model is a stripped-down version of the HTR system I implemented for my thesis.
What remains is what I think is the bare minimum to recognize text with an acceptable accuracy. What remains is what I think is the bare minimum to recognize text with an acceptable accuracy.
The implementation only depends on numpy, cv2 and tensorflow imports. The implementation only depends on numpy, cv2 and tensorflow imports.
It consists of 5 CNN layers, 2 RNN (LSTM) layers and the CTC loss and decoding layer. It consists of 5 CNN layers, 2 RNN (LSTM) layers and the CTC loss and decoding layer.
The illustration below gives an overview of the NN (green: operations, pink: data) and here follows a short description: The illustration below gives an overview of the NN (green: operations, pink: data flowing through NN) and here follows a short description:
* The input image is a gray-value image and has a size of 128x32 * The input image is a gray-value image and has a size of 128x32
* 5 CNN layers map the input image to a feature sequence of size 32x256 * 5 CNN layers map the input image to a feature sequence of size 32x256
...@@ -99,11 +102,11 @@ The illustration below gives an overview of the NN (green: operations, pink: dat ...@@ -99,11 +102,11 @@ The illustration below gives an overview of the NN (green: operations, pink: dat
Around 70% of the words from IAM are correctly recognized by the NN. Around 70% of the words from IAM are correctly recognized by the NN.
If you need a better accuracy, here are some ideas on how to improve it: If you need a better accuracy, here are some ideas on how to improve it:
* Data augmentation: increase dataset-size by applying random transformations to the input images * Data augmentation: increase dataset-size by applying random transformations to the input images. At the moment, only random distortions are performed
* Remove cursive writing style in the input images (see [DeslantImg](https://github.com/githubharald/DeslantImg)) * Remove cursive writing style in the input images (see [DeslantImg](https://github.com/githubharald/DeslantImg))
* Increase input size (if input of NN is large enough, complete text-lines can be used) * Increase input size (if input of NN is large enough, complete text-lines can be used)
* Add more CNN layers * Add more CNN layers
* Replace LSTM by MD-LSTM * Replace LSTM by multidimensional LSTM
* Decoder: either use vanilla beam search decoding (included with TF) or use word beam search decoding (see [CTCWordBeamSearch](https://github.com/githubharald/CTCWordBeamSearch)) to constrain the output to dictionary words * Decoder: either use vanilla beam search decoding (included with TF) or use word beam search decoding (see [CTCWordBeamSearch](https://github.com/githubharald/CTCWordBeamSearch)) to constrain the output to dictionary words
* Text correction: if the recognized word is not contained in a dictionary, search for the most similar one * Text correction: if the recognized word is not contained in a dictionary, search for the most similar one
......
Validation accuracy of saved model: 0.6664347826086956
\ No newline at end of file
...@@ -53,12 +53,14 @@ class DataLoader: ...@@ -53,12 +53,14 @@ class DataLoader:
# put sample into list # put sample into list
self.samples.append(Sample(gtText, fileName)) self.samples.append(Sample(gtText, fileName))
# split into training and validation set: 95% - 5% # split into training and validation set: 95% - 5%
splitIdx = int(0.95 * len(self.samples)) splitIdx = int(0.95 * len(self.samples))
self.trainSamples = self.samples[:splitIdx] self.trainSamples = self.samples[:splitIdx]
self.validationSamples = self.samples[splitIdx:] self.validationSamples = self.samples[splitIdx:]
# number of randomly chosen samples per epoch for training
self.numTrainSamplesPerEpoch = 25000
# start with train set # start with train set
self.trainSet() self.trainSet()
...@@ -67,10 +69,11 @@ class DataLoader: ...@@ -67,10 +69,11 @@ class DataLoader:
def trainSet(self): def trainSet(self):
"switch to training set" "switch to randomly chosen subset of training set"
self.dataAugmentation = True self.dataAugmentation = True
self.currIdx = 0 self.currIdx = 0
self.samples = self.trainSamples random.shuffle(self.trainSamples)
self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]
def validationSet(self): def validationSet(self):
...@@ -80,12 +83,6 @@ class DataLoader: ...@@ -80,12 +83,6 @@ class DataLoader:
self.samples = self.validationSamples self.samples = self.validationSamples
def shuffle(self):
"shuffle current set"
self.currIdx = 0
random.shuffle(self.samples)
def getIteratorInfo(self): def getIteratorInfo(self):
"current batch index and overall number of batches" "current batch index and overall number of batches"
return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize) return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)
......
...@@ -7,6 +7,7 @@ from SamplePreprocessor import preprocess ...@@ -7,6 +7,7 @@ from SamplePreprocessor import preprocess
# filenames and paths to data # filenames and paths to data
fnCharList = '../model/charList.txt' fnCharList = '../model/charList.txt'
fnAccuracy = '../model/accuracy.txt'
fnTrain = '../data/' fnTrain = '../data/'
fnInfer = '../data/test.png' fnInfer = '../data/test.png'
...@@ -34,7 +35,6 @@ def train(filePath): ...@@ -34,7 +35,6 @@ def train(filePath):
# train # train
print('Train NN') print('Train NN')
loader.trainSet() loader.trainSet()
loader.shuffle()
while loader.hasNext(): while loader.hasNext():
iterInfo = loader.getIteratorInfo() iterInfo = loader.getIteratorInfo()
batch = loader.getNext() batch = loader.getNext()
...@@ -70,6 +70,7 @@ def train(filePath): ...@@ -70,6 +70,7 @@ def train(filePath):
bestAccuracy = accuracy bestAccuracy = accuracy
noImprovementSince = 0 noImprovementSince = 0
model.save() model.save()
open(fnAccuracy, 'w').write('Validation accuracy of saved model: '+str(accuracy))
else: else:
print('Accuracy not improved') print('Accuracy not improved')
noImprovementSince += 1 noImprovementSince += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment