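"""Main script of the handwritten text recognition model: trains or validates it on the
IAM dataset, or recognizes the text in a single image (../data/test.png by default)."""
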
import argparse
import json

import cv2
import editdistance
from path import Path

from dataloader_iam import DataLoaderIAM, Batch
from model import Model, DecoderType
from preprocess import preprocess


class FilePaths:
"""Filenames and paths to data."""
fn_char_list = '../model/charList.txt'
fn_summary = '../model/summary.json'
fn_infer = '../data/test.png'
fn_corpus = '../data/corpus.txt'


def write_summary(char_error_rates, word_accuracies):
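    """Writes character error rates and word accuracies per epoch to the JSON summary file."""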
with open(FilePaths.fn_summary, 'w') as f:
json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)


def train(model, loader):
"""Trains NN."""
epoch = 0 # number of training epochs since start
summary_char_error_rates = []
summary_word_accuracies = []
    best_char_error_rate = float('inf')  # best validation character error rate
    no_improvement_since = 0  # number of epochs without improvement of character error rate
early_stopping = 25 # stop training after this number of epochs without improvement
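    # train until the character error rate on the validation set stops improving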
while True:
epoch += 1
print('Epoch:', epoch)
# train
print('Train NN')
loader.train_set()
while loader.has_next():
iter_info = loader.get_iterator_info()
batch = loader.get_next()
loss = model.train_batch(batch)
print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}')
# validate
char_error_rate, word_accuracy = validate(model, loader)
# write summary
summary_char_error_rates.append(char_error_rate)
summary_word_accuracies.append(word_accuracy)
write_summary(summary_char_error_rates, summary_word_accuracies)
        # if best validation error rate so far, save model parameters
if char_error_rate < best_char_error_rate:
print('Character error rate improved, save model')
best_char_error_rate = char_error_rate
no_improvement_since = 0
model.save()
else:
            print(f'Character error rate not improved, best so far: {best_char_error_rate * 100.0}%')
no_improvement_since += 1
# stop training if no more improvement in the last x epochs
if no_improvement_since >= early_stopping:
                print(f'No more improvement for {early_stopping} epochs. Training stopped.')
break


def validate(model, loader):
"""Validates NN."""
print('Validate NN')
loader.validation_set()
num_char_err = 0
num_char_total = 0
num_word_ok = 0
num_word_total = 0
while loader.has_next():
iter_info = loader.get_iterator_info()
print(f'Batch: {iter_info[0]} / {iter_info[1]}')
batch = loader.get_next()
(recognized, _) = model.infer_batch(batch)
print('Ground truth -> Recognized')
for i in range(len(recognized)):
num_word_ok += 1 if batch.gt_texts[i] == recognized[i] else 0
num_word_total += 1
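            # the edit distance between recognized and ground-truth text counts character-level errors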
dist = editdistance.eval(recognized[i], batch.gt_texts[i])
num_char_err += dist
num_char_total += len(batch.gt_texts[i])
            print('[OK]' if dist == 0 else f'[ERR:{dist}]', '"' + batch.gt_texts[i] + '"', '->',
                  '"' + recognized[i] + '"')
# print validation result
char_error_rate = num_char_err / num_char_total
word_accuracy = num_word_ok / num_word_total
print(f'Character error rate: {char_error_rate * 100.0}%. Word accuracy: {word_accuracy * 100.0}%.')
return char_error_rate, word_accuracy


def infer(model, fn_img):
    """Recognizes text in image provided by file path."""
    img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
    assert img is not None, f'Image not found: {fn_img}'  # cv2.imread returns None on failure
    img = preprocess(img, Model.img_size, dynamic_width=True)
batch = Batch(None, [img])
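    # second argument asks infer_batch to also return the probability of the recognized text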
recognized, probability = model.infer_batch(batch, True)
print(f'Recognized: "{recognized[0]}"')
print(f'Probability: {probability[0]}')


def main():
    """Parses command-line arguments and runs training, validation, or inference."""
parser = argparse.ArgumentParser()
parser.add_argument('--train', help='train the NN', action='store_true')
parser.add_argument('--validate', help='validate the NN', action='store_true')
parser.add_argument('--decoder', choices=['bestpath', 'beamsearch', 'wordbeamsearch'], default='bestpath',
help='CTC decoder')
parser.add_argument('--batch_size', help='batch size', type=int, default=100)
parser.add_argument('--data_dir', help='directory containing IAM dataset', type=Path, required=False)
parser.add_argument('--fast', help='use lmdb to load images', action='store_true')
parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
args = parser.parse_args()
# set chosen CTC decoder
decoder_mapping = {'bestpath': DecoderType.BestPath,
'beamsearch': DecoderType.BeamSearch,
'wordbeamsearch': DecoderType.WordBeamSearch}
decoder_type = decoder_mapping[args.decoder]
# train or validate on IAM dataset
if args.train or args.validate:
# load training data, create TF model
loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast)
        # save characters of model for inference mode
        with open(FilePaths.fn_char_list, 'w') as f:
            f.write(''.join(loader.char_list))
        # save words contained in dataset into file
        with open(FilePaths.fn_corpus, 'w') as f:
            f.write(' '.join(loader.train_words + loader.validation_words))
# execute training or validation
if args.train:
model = Model(loader.char_list, decoder_type)
train(model, loader)
elif args.validate:
model = Model(loader.char_list, decoder_type, must_restore=True)
validate(model, loader)
# infer text on test image
else:
        with open(FilePaths.fn_char_list) as f:
            char_list = f.read()
        model = Model(char_list, decoder_type, must_restore=True, dump=args.dump)
infer(model, FilePaths.fn_infer)


if __name__ == '__main__':
main()
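
# Example invocations (the data directory is a placeholder for a local copy of the IAM dataset):
#   python main.py --train --data_dir ../data
#   python main.py --validate --data_dir ../data
#   python main.py                          # recognize text in ../data/test.png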