Skip to content
Snippets Groups Projects
Commit 97192a92 authored by Harald Scheidl's avatar Harald Scheidl
Browse files

integrate word beam search decoding

parent da4c197a
No related branches found
No related tags found
No related merge requests found
......@@ -3,3 +3,4 @@ data/words.txt
src/__pycache__/
model/checkpoint
model/snapshot-*
*.so
......@@ -30,12 +30,31 @@ Recognized: "little"
* `--train`: train the NN, details see below.
* `--validate`: validate the NN, details see below.
* `--beamsearch`: use beam search decoding (better, but slower) instead of best path decoding.
* `--beamsearch`: use (vanilla) beam search decoding (better, but slower) instead of best path decoding.
* `--wordbeamsearch`: use word beam search decoding (only outputs words contained in a dictionary) instead of best path decoding. This is a custom TF operation and must be compiled from source, more information see corresponding section below. It should **not** be used when training the NN.
If neither `--train` nor `--validate` is specified, the NN infers the text from the test image (`data/test.png`).
Two examples: if you want to infer using beam search, execute `python main.py --beamsearch`, while you have to execute `python main.py --train --beamsearch` if you want to train the NN and do the validation using beam search.
## Integrate word beam search decoding
Besides the two decoders shipped with TF, it is possible to use word beam search decoding.
Using this decoder, the recognized words are constrained to those contained in a dictionary, but arbitrary non-word character strings (numbers, punctuation marks) are still possible.
Follow these instructions to integrate word beam search decoding:
1. Clone repository [CTCWordBeamSearch](https://github.com/githubharald/CTCWordBeamSearch).
2. Compile TF custom operation (follow instructions given in README).
3. Copy binary `TFWordBeamSearch.so` from the CTCWordBeamSearch repository to the `src/` directory of the SimpleHTR repository.
Word beam search can now be enabled by setting the corresponding command line argument.
The dictionary is created (in training and validation mode) by using all words contained in the IAM dataset (i.e. also including words from validation set) and is saved into the file `data/corpus.txt`.
Further, the list of word-characters can be found in the file `wordCharList.txt`.
Beam width is set to 50 to conform with the beam width of vanilla beam search decoding.
Using this configuration, a character error rate of 10% and a word accuracy of 81% is achieved.
## Train model
### IAM dataset
......
This diff is collapsed.
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
......@@ -58,6 +58,10 @@ class DataLoader:
self.trainSamples = self.samples[:splitIdx]
self.validationSamples = self.samples[splitIdx:]
# put words into lists
self.trainWords = [x.gtText for x in self.trainSamples]
self.validationWords = [x.gtText for x in self.validationSamples]
# number of randomly chosen samples per epoch for training
self.numTrainSamplesPerEpoch = 25000
......
......@@ -2,6 +2,12 @@ import sys
import tensorflow as tf
class DecoderType:
BestPath = 0
BeamSearch = 1
WordBeamSearch = 2
class Model:
"minimalistic TF model for HTR"
......@@ -10,10 +16,10 @@ class Model:
imgSize = (128, 32)
maxTextLen = 32
def __init__(self, charList, useBeamSearch=False, mustRestore=False):
def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False):
"init model: add CNN, RNN and CTC and initialize TF"
self.charList = charList
self.useBeamSearch = useBeamSearch
self.decoderType = decoderType
self.mustRestore = mustRestore
self.snapID = 0
......@@ -90,10 +96,23 @@ class Model:
self.seqLen = tf.placeholder(tf.int32, [None])
loss = tf.nn.ctc_loss(labels=self.gtTexts, inputs=ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True)
# decoder: either best path decoding or beam search decoding
if self.useBeamSearch:
decoder = tf.nn.ctc_beam_search_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50, merge_repeated=False)
else:
if self.decoderType == DecoderType.BestPath:
decoder = tf.nn.ctc_greedy_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen)
elif self.decoderType == DecoderType.BeamSearch:
decoder = tf.nn.ctc_beam_search_decoder(inputs=ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50, merge_repeated=False)
elif self.decoderType == DecoderType.WordBeamSearch:
# import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
word_beam_search_module = tf.load_op_library('TFWordBeamSearch.so')
# prepare information about language (dictionary, characters in dataset, characters forming words)
chars = str().join(self.charList)
wordChars = open('../model/wordCharList.txt').read().splitlines()[0]
corpus = open('../data/corpus.txt').read()
# decode using the "Words" mode of word beam search
decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(ctcIn3dTBC, dim=2), 50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))
# return a CTC operation to compute the loss and a CTC operation to decode the RNN output
return (tf.reduce_mean(loss), decoder)
......@@ -144,14 +163,28 @@ class Model:
return (indices, values, shape)
def fromSparse(self, ctcOutput):
"extract texts from sparse tensor"
def decoderOutputToText(self, ctcOutput):
"extract texts from output of CTC decoder"
# contains string of labels for each batch element
encodedLabelStrs = [[] for i in range(Model.batchSize)]
# word beam search: label strings terminated by blank
if self.decoderType == DecoderType.WordBeamSearch:
blank=len(self.charList)
for b in range(Model.batchSize):
for label in ctcOutput[b]:
if label==blank:
break
encodedLabelStrs[b].append(label)
# TF decoders: label strings are contained in sparse tensor
else:
# ctc returns tuple, first element is SparseTensor
decoded=ctcOutput[0][0]
# go over all indices and save mapping: batch -> values
idxDict = { b : [] for b in range(Model.batchSize) }
encodedLabelStrs = [[] for i in range(Model.batchSize)]
for (idx, idx2d) in enumerate(decoded.indices):
label = decoded.values[idx]
batchElement = idx2d[0] # index according to [b,t]
......@@ -173,7 +206,7 @@ class Model:
def inferBatch(self, batch):
"feed a batch into the NN to recngnize the texts"
decoded = self.sess.run(self.decoder, { self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * Model.batchSize } )
return self.fromSparse(decoded)
return self.decoderOutputToText(decoded)
def save(self):
......
......@@ -3,7 +3,7 @@ import argparse
import cv2
import editdistance
from DataLoader import DataLoader, Batch
from Model import Model
from Model import Model, DecoderType
from SamplePreprocessor import preprocess
......@@ -13,6 +13,7 @@ class FilePaths:
fnAccuracy = '../model/accuracy.txt'
fnTrain = '../data/'
fnInfer = '../data/test.png'
fnCorpus = '../data/corpus.txt'
def train(model, loader):
......@@ -99,8 +100,15 @@ def main():
parser.add_argument("--train", help="train the NN", action="store_true")
parser.add_argument("--validate", help="validate the NN", action="store_true")
parser.add_argument("--beamsearch", help="use beam search instead of best path decoding", action="store_true")
parser.add_argument("--wordbeamsearch", help="use word beam search instead of best path decoding", action="store_true")
args = parser.parse_args()
decoderType = DecoderType.BestPath
if args.beamsearch:
decoderType = DecoderType.BeamSearch
elif args.wordbeamsearch:
decoderType = DecoderType.WordBeamSearch
# train or validate on IAM dataset
if args.train or args.validate:
# load training data, create TF model
......@@ -109,18 +117,21 @@ def main():
# save characters of model for inference mode
open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))
# save words contained in dataset into file
open(FilePaths.fnCorpus, 'w').write(str(' ').join(loader.trainWords + loader.validationWords))
# execute training or validation
if args.train:
model = Model(loader.charList, args.beamsearch)
model = Model(loader.charList, decoderType)
train(model, loader)
elif args.validate:
model = Model(loader.charList, args.beamsearch, mustRestore=True)
model = Model(loader.charList, decoderType, mustRestore=True)
validate(model, loader)
# infer text on test image
else:
print(open(FilePaths.fnAccuracy).read())
model = Model(open(FilePaths.fnCharList).read(), args.beamsearch, mustRestore=True)
model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True)
infer(model, FilePaths.fnInfer)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment