Commit a451e5f0 authored by Harald Scheidl

validation: removed trainBatch

parent d2044457
@@ -3,7 +3,7 @@
 Handwritten Text Recognition (HTR) system implemented with TensorFlow (TF) and trained on the IAM off-line HTR dataset.
 This Neural Network (NN) model recognizes the text contained in the images of segmented words as shown in the illustration below.
 As these word-images are smaller than images of complete text-lines, the NN can be kept small and training on the CPU is feasible.
-More than 86% of the samples from the validation-set are correctly recognized.
+More than 70% of the samples from the validation-set are correctly recognized.
 I will give some hints how to extend the model in case you need larger input-images or want better recognition accuracy.
 ![img](./doc/htr.png)
@@ -83,7 +83,7 @@ Ground truth -> Recognized
 [OK] "told" -> "told"
 [OK] "her" -> "her"
 ...
-Correctly recognized words: 86.34782608695653 %
+Correctly recognized words: 71.70434782608696 %
``` ```
 ### Other datasets
@@ -112,7 +112,7 @@ The illustration below gives an overview of the NN (green: operations, pink: dat
 ### Improve accuracy
-Around 86% of the words from IAM are correctly recognized by the NN.
+Around 71% of the words from IAM are correctly recognized by the NN.
 If you need a better accuracy, here are some ideas on how to improve it:
 * Data augmentation: increase dataset-size by applying random transformations to the input images. At the moment, only random distortions are performed
......
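Note on the data augmentation bullet in the README hunk above: the following is a minimal sketch of one possible random transformation, a horizontal shift with padding. It is not the repository's augmentation code. The function name `random_shift`, its parameters, and the list-of-rows image representation are assumptions chosen to keep the example dependency-free.

```python
import random


def random_shift(img, max_shift, rng, fill=255):
    """Shift a 2D grayscale image (a list of equal-length rows) horizontally.

    The offset is drawn uniformly from [-max_shift, max_shift]. Pixels that
    move out of the frame are dropped, and the exposed columns are padded
    with `fill` (255 is assumed to be the background value).
    """
    width = len(img[0])
    shift = rng.randint(-max_shift, max_shift)
    # Clamp so that slicing below never wraps around the row.
    shift = max(-width, min(width, shift))
    shifted = []
    for row in img:
        if shift >= 0:
            # Move content to the right, pad on the left.
            new_row = [fill] * shift + row[:width - shift]
        else:
            # Move content to the left, pad on the right.
            new_row = row[-shift:] + [fill] * (-shift)
        shifted.append(new_row)
    return shifted


# Usage: a tiny 2x4 "image" and a seeded generator for reproducibility.
image = [[0, 1, 2, 3], [4, 5, 6, 7]]
augmented = random_shift(image, max_shift=2, rng=random.Random(0))
```

Passing the random generator explicitly keeps the transformation reproducible, which helps when comparing training runs.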
@@ -89,7 +89,7 @@ class Model:
 loss = tf.nn.ctc_loss(labels=self.gtTexts, inputs=ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True)
 # decoder: either best path decoding or beam search decoding
 if self.useBeamSearch:
-    decoder = tf.nn.ctc_beam_search_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen, beam_width=25, merge_repeated=False)
+    decoder = tf.nn.ctc_beam_search_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50, merge_repeated=False)
 else:
     decoder = tf.nn.ctc_greedy_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen)
 return (tf.reduce_mean(loss), decoder)
......
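Note on the decoder hunk above: for readers unfamiliar with the two options, here is a minimal pure-Python sketch of best path decoding, the strategy that `tf.nn.ctc_greedy_decoder` is based on. It picks the most likely class at each time step, merges repeated labels, then drops blanks. The function name `best_path_decode` and the toy scores are hypothetical, and the blank index is passed explicitly rather than taken from TF. Beam search decoding, which keeps up to `beam_width` candidate labelings per step, is not shown.

```python
def best_path_decode(scores, blank):
    """Decode one CTC output sequence with best path (greedy) decoding.

    scores: list of time steps, each a list of per-class scores.
    blank:  index of the CTC blank label.
    Returns the decoded label indices.
    """
    # 1) Most likely class per time step.
    best = [max(range(len(step)), key=step.__getitem__) for step in scores]

    # 2) Merge repeated labels, 3) remove blanks.
    decoded = []
    prev = None
    for idx in best:
        if idx != prev and idx != blank:
            decoded.append(idx)
        prev = idx
    return decoded


# Usage: classes are 0='a', 1='b', 2=blank (toy values, not model output).
toy_scores = [
    [0.8, 0.1, 0.1],  # a
    [0.7, 0.2, 0.1],  # a (repeat, merged)
    [0.1, 0.1, 0.8],  # blank (separates the two a's)
    [0.6, 0.3, 0.1],  # a
    [0.2, 0.7, 0.1],  # b
    [0.1, 0.8, 0.1],  # b (repeat, merged)
]
labels = best_path_decode(toy_scores, blank=2)
text = ''.join('ab'[i] for i in labels)
```

The blank between the two runs of `a` is what allows a doubled letter to survive the merge step.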
@@ -52,7 +52,6 @@ def train(filePath):
 iterInfo = loader.getIteratorInfo()
 print('Batch:', iterInfo[0],'/', iterInfo[1])
 batch = loader.getNext()
-loss = model.trainBatch(batch)
 recognized = model.inferBatch(batch)
 print('Ground truth -> Recognized')
@@ -102,7 +101,6 @@ def validate(filePath):
 iterInfo = loader.getIteratorInfo()
 print('Batch:', iterInfo[0],'/', iterInfo[1])
 batch = loader.getNext()
-loss = model.trainBatch(batch)
 recognized = model.inferBatch(batch)
 print('Ground truth -> Recognized')
......
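Note on the hunks above: with `trainBatch` removed, these loops only run inference and compare the result with the ground truth. A minimal sketch of that comparison and of the word-accuracy figure reported in the README follows. The function name `word_accuracy` is hypothetical, and the `[ERR]` tag is an assumed counterpart to the `[OK]` tag visible in the README output.

```python
def word_accuracy(ground_truth, recognized):
    """Return the percentage of words that were recognized exactly.

    Only compares strings; unlike a training step, nothing here
    updates model weights.
    """
    if len(ground_truth) != len(recognized):
        raise ValueError('both lists must have the same length')
    if not ground_truth:
        return 0.0

    num_ok = 0
    for gt, rec in zip(ground_truth, recognized):
        is_ok = gt == rec
        num_ok += 1 if is_ok else 0
        # '[OK]' follows the README log format; '[ERR]' is an assumption.
        print('[OK]' if is_ok else '[ERR]', '"' + gt + '"', '->', '"' + rec + '"')

    return 100.0 * num_ok / len(ground_truth)


# Usage with made-up recognition results.
accuracy = word_accuracy(['told', 'her', 'little'], ['told', 'her', 'litle'])
print('Correctly recognized words:', accuracy, '%')
```

Measuring accuracy on samples the model has just been trained on would inflate the result, which is consistent with the drop from about 86% to about 71% recorded in this commit.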