From d807bc44ea7d54dc7e7d9d407b61ed72a770f64a Mon Sep 17 00:00:00 2001
From: Harald Scheidl <harald@newpc.com>
Date: Tue, 1 Mar 2022 18:41:08 +0100
Subject: [PATCH] in validation mode use charset from trained model and not from dataset, see issue #127

---
 model/.gitignore |  5 +++--
 src/main.py      | 54 +++++++++++++++++++++++++++++-------------------
 2 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/model/.gitignore b/model/.gitignore
index 86d0cb2..13bb8e8 100644
--- a/model/.gitignore
+++ b/model/.gitignore
@@ -1,4 +1,5 @@
 # Ignore everything in this directory
 *
-# Except this file
-!.gitignore
\ No newline at end of file
+# Except this file and wordCharList.txt
+!.gitignore
+wordCharList.txt
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 3b3ac43..420461e 100644
--- a/src/main.py
+++ b/src/main.py
@@ -36,6 +36,11 @@ def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -
         json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
 
 
+def char_list_from_file() -> List[str]:
+    with open(FilePaths.fn_char_list) as f:
+        return list(f.read())
+
+
 def train(model: Model,
           loader: DataLoaderIAM,
           line_mode: bool,
@@ -45,7 +50,7 @@ def train(model: Model,
     summary_char_error_rates = []
     summary_word_accuracies = []
     preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode)
-    best_char_error_rate = float('inf')  # best valdiation character error rate
+    best_char_error_rate = float('inf')  # best validation character error rate
     no_improvement_since = 0  # number of epochs no improvement of character error rate occurred
     # stop training after this number of epochs without improvement
     while True:
@@ -133,8 +138,8 @@ def infer(model: Model, fn_img: Path) -> None:
     print(f'Probability: {probability[0]}')
 
 
-def main():
-    """Main function."""
+def parse_args() -> argparse.Namespace:
+    """Parses arguments from the command line."""
     parser = argparse.ArgumentParser()
     parser.add_argument('--mode', choices=['train', 'validate', 'infer'], default='infer')
@@ -146,41 +151,48 @@ def main():
     parser.add_argument('--img_file', help='Image used for inference.', type=Path, default='../data/word.png')
     parser.add_argument('--early_stopping', help='Early stopping epochs.', type=int, default=25)
     parser.add_argument('--dump', help='Dump output of NN to CSV file(s).', action='store_true')
-    args = parser.parse_args()
 
-    # set chosen CTC decoder
+    return parser.parse_args()
+
+
+def main():
+    """Main function."""
+
+    # parse arguments and set CTC decoder
+    args = parse_args()
     decoder_mapping = {'bestpath': DecoderType.BestPath,
                        'beamsearch': DecoderType.BeamSearch,
                        'wordbeamsearch': DecoderType.WordBeamSearch}
     decoder_type = decoder_mapping[args.decoder]
 
-    # train or validate on IAM dataset
-    if args.mode in ['train', 'validate']:
-        # load training data, create TF model
+    # train the model
+    if args.mode == 'train':
         loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast)
-        char_list = loader.char_list
 
         # when in line mode, take care to have a whitespace in the char list
+        char_list = loader.char_list
         if args.line_mode and ' ' not in char_list:
             char_list = [' '] + char_list
 
-        # save characters of model for inference mode
-        open(FilePaths.fn_char_list, 'w').write(''.join(char_list))
+        # save characters and words
+        with open(FilePaths.fn_char_list, 'w') as f:
+            f.write(''.join(char_list))
+
+        with open(FilePaths.fn_corpus, 'w') as f:
+            f.write(' '.join(loader.train_words + loader.validation_words))
 
-        # save words contained in dataset into file
-        open(FilePaths.fn_corpus, 'w').write(' '.join(loader.train_words + loader.validation_words))
+        model = Model(char_list, decoder_type)
+        train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping)
 
-        # execute training or validation
-        if args.mode == 'train':
-            model = Model(char_list, decoder_type)
-            train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping)
-        elif args.mode == 'validate':
-            model = Model(char_list, decoder_type, must_restore=True)
-            validate(model, loader, args.line_mode)
+    # evaluate it on the validation set
+    elif args.mode == 'validate':
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast)
+        model = Model(char_list_from_file(), decoder_type, must_restore=True)
+        validate(model, loader, args.line_mode)
 
     # infer text on test image
     elif args.mode == 'infer':
-        model = Model(list(open(FilePaths.fn_char_list).read()), decoder_type, must_restore=True, dump=args.dump)
+        model = Model(char_list_from_file(), decoder_type, must_restore=True, dump=args.dump)
         infer(model, args.img_file)
 

--
GitLab
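For context, the pattern this change relies on is: the character list is written to FilePaths.fn_char_list once during training, and char_list_from_file() reads it back whenever a trained model is restored (validation or inference), so decoding always uses the charset the checkpoint was trained with rather than whatever the current dataset yields. Below is a minimal standalone sketch of that save/restore pattern; the file name and the helper names save_char_list/load_char_list are hypothetical stand-ins for the repository's FilePaths entry and functions.

from typing import List

CHAR_LIST_FILE = 'charList.txt'  # hypothetical path standing in for FilePaths.fn_char_list


def save_char_list(char_list: List[str]) -> None:
    # Training mode: persist the exact characters the model was built with.
    with open(CHAR_LIST_FILE, 'w') as f:
        f.write(''.join(char_list))


def load_char_list() -> List[str]:
    # Validation/inference mode: rebuild the model with the saved characters
    # rather than re-deriving them from the dataset.
    with open(CHAR_LIST_FILE) as f:
        return list(f.read())

In the patch itself, the equivalents are the write in the 'train' branch and the new char_list_from_file() used by the 'validate' and 'infer' branches.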