main.py
    import argparse
    import json
    
    import cv2
    import editdistance
    from path import Path
    
    from dataloader_iam import DataLoaderIAM, Batch
    from model import Model, DecoderType
    from preprocess import preprocess
    
    
    class FilePaths:
        """Filenames and paths to data."""
        fn_char_list = '../model/charList.txt'
        fn_summary = '../model/summary.json'
        fn_infer = '../data/test.png'
        fn_corpus = '../data/corpus.txt'
    
    
    def write_summary(char_error_rates, word_accuracies):
        """Writes training summary file for NN."""
        with open(FilePaths.fn_summary, 'w') as f:
            json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
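
    # The resulting summary.json then has this shape (illustrative values,
    # assuming two completed epochs):
    #   {"charErrorRates": [0.21, 0.15], "wordAccuracies": [0.55, 0.67]}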
    
    
    def train(model, loader):
        """Trains NN."""
        epoch = 0  # number of training epochs since start
        summary_char_error_rates = []
        summary_word_accuracies = []
        best_char_error_rate = float('inf')  # best validation character error rate
        no_improvement_since = 0  # number of epochs without improvement of the character error rate
        early_stopping = 25  # stop training after this number of epochs without improvement
        while True:
            epoch += 1
            print('Epoch:', epoch)
    
            # train
            print('Train NN')
            loader.train_set()
            while loader.has_next():
                iter_info = loader.get_iterator_info()
                batch = loader.get_next()
                loss = model.train_batch(batch)
                print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}')
    
            # validate
            char_error_rate, word_accuracy = validate(model, loader)
    
            # write summary
            summary_char_error_rates.append(char_error_rate)
            summary_word_accuracies.append(word_accuracy)
            write_summary(summary_char_error_rates, summary_word_accuracies)
    
            # if best validation accuracy so far, save model parameters
            if char_error_rate < best_char_error_rate:
                print('Character error rate improved, save model')
                best_char_error_rate = char_error_rate
                no_improvement_since = 0
                model.save()
            else:
                print(f'Character error rate not improved, best so far: {best_char_error_rate * 100.0}%')
                no_improvement_since += 1
    
            # stop training if no more improvement in the last x epochs
            if no_improvement_since >= early_stopping:
                print(f'No more improvement for {early_stopping} epochs. Training stopped.')
                break
    
    
    def validate(model, loader):
        """Validates NN."""
        print('Validate NN')
        loader.validation_set()
        num_char_err = 0
        num_char_total = 0
        num_word_ok = 0
        num_word_total = 0
        while loader.has_next():
            iter_info = loader.get_iterator_info()
            print(f'Batch: {iter_info[0]} / {iter_info[1]}')
            batch = loader.get_next()
            (recognized, _) = model.infer_batch(batch)
    
            print('Ground truth -> Recognized')
            for i in range(len(recognized)):
                num_word_ok += 1 if batch.gt_texts[i] == recognized[i] else 0
                num_word_total += 1
                dist = editdistance.eval(recognized[i], batch.gt_texts[i])
                num_char_err += dist
                num_char_total += len(batch.gt_texts[i])
                print('[OK]' if dist == 0 else f'[ERR:{dist}]',
                      f'"{batch.gt_texts[i]}"', '->', f'"{recognized[i]}"')
    
        # compute validation metrics: character error rate (summed edit distance
        # divided by total number of ground-truth characters) and word accuracy
        char_error_rate = num_char_err / num_char_total
        word_accuracy = num_word_ok / num_word_total
        print(f'Character error rate: {char_error_rate * 100.0}%. Word accuracy: {word_accuracy * 100.0}%.')
        return char_error_rate, word_accuracy
    
    
    def infer(model, fn_img):
        """Recognizes text in image provided by file path."""
        img = preprocess(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE), Model.img_size, dynamic_width=True)
        batch = Batch(None, [img])
        recognized, probability = model.infer_batch(batch, True)
        print(f'Recognized: "{recognized[0]}"')
        print(f'Probability: {probability[0]}')
    
    
    def main():
        """Main function."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--train', help='train the NN', action='store_true')
        parser.add_argument('--validate', help='validate the NN', action='store_true')
        parser.add_argument('--decoder', choices=['bestpath', 'beamsearch', 'wordbeamsearch'], default='bestpath',
                            help='CTC decoder')
        parser.add_argument('--batch_size', help='batch size', type=int, default=100)
        parser.add_argument('--data_dir', help='directory containing IAM dataset', type=Path, required=False)
        parser.add_argument('--fast', help='use lmdb to load images', action='store_true')
        parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
        args = parser.parse_args()
    
        # set chosen CTC decoder
        decoder_mapping = {'bestpath': DecoderType.BestPath,
                           'beamsearch': DecoderType.BeamSearch,
                           'wordbeamsearch': DecoderType.WordBeamSearch}
        decoder_type = decoder_mapping[args.decoder]
    
        # train or validate on IAM dataset
        if args.train or args.validate:
            # load training data, create TF model
            loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast)
    
            # save characters of model for inference mode
            with open(FilePaths.fn_char_list, 'w') as f:
                f.write(''.join(loader.char_list))

            # save words contained in dataset into file
            with open(FilePaths.fn_corpus, 'w') as f:
                f.write(' '.join(loader.train_words + loader.validation_words))
    
            # execute training or validation
            if args.train:
                model = Model(loader.char_list, decoder_type)
                train(model, loader)
            elif args.validate:
                model = Model(loader.char_list, decoder_type, must_restore=True)
                validate(model, loader)
    
        # infer text on test image
        else:
            with open(FilePaths.fn_char_list) as f:
                char_list = f.read()
            model = Model(char_list, decoder_type, must_restore=True, dump=args.dump)
            infer(model, FilePaths.fn_infer)
    
    
    if __name__ == '__main__':
        main()
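
    # Example invocations (a sketch, not exhaustive; assumes the IAM data layout
    # expected by DataLoaderIAM and, for inference, that '../data/test.png' exists):
    #   python main.py --train --data_dir ../data
    #   python main.py --validate --data_dir ../data --batch_size 50
    #   python main.py --decoder wordbeamsearch    # infer text in '../data/test.png'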