From d807bc44ea7d54dc7e7d9d407b61ed72a770f64a Mon Sep 17 00:00:00 2001
From: Harald Scheidl <harald@newpc.com>
Date: Tue, 1 Mar 2022 18:41:08 +0100
Subject: [PATCH] in validation mode use charset from trained model and not
 from dataset, see issue #127

---
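Note: after this change, `--mode validate` restores the character set that was
written to FilePaths.fn_char_list during training instead of rebuilding it from
the dataset, so the CTC output indices keep the meaning they had when the model
was trained. A minimal usage sketch follows; the module layout assumed below
(`from main import char_list_from_file`, `from model import Model, DecoderType`)
is taken from the repository layout and is not part of this patch:

    # restore a trained model with the charset saved by a previous `--mode train` run
    from main import char_list_from_file
    from model import Model, DecoderType

    chars = char_list_from_file()           # characters written at training time
    model = Model(chars, DecoderType.BestPath, must_restore=True)

Equivalently, run `python main.py --mode validate` (plus the usual dataset
arguments) from the src/ directory.
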
 model/.gitignore |  5 +++--
 src/main.py      | 54 +++++++++++++++++++++++++++++-------------------
 2 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/model/.gitignore b/model/.gitignore
index 86d0cb2..13bb8e8 100644
--- a/model/.gitignore
+++ b/model/.gitignore
@@ -1,4 +1,5 @@
 # Ignore everything in this directory
 *
-# Except this file
-!.gitignore
\ No newline at end of file
+# Except this file and wordCharList.txt
+!.gitignore
+!wordCharList.txt
\ No newline at end of file
diff --git a/src/main.py b/src/main.py
index 3b3ac43..420461e 100644
--- a/src/main.py
+++ b/src/main.py
@@ -36,6 +36,11 @@ def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -
         json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
 
 
+def char_list_from_file() -> List[str]:
+    with open(FilePaths.fn_char_list) as f:
+        return list(f.read())
+
+
 def train(model: Model,
           loader: DataLoaderIAM,
           line_mode: bool,
@@ -45,7 +50,7 @@ def train(model: Model,
     summary_char_error_rates = []
     summary_word_accuracies = []
     preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode)
-    best_char_error_rate = float('inf')  # best valdiation character error rate
+    best_char_error_rate = float('inf')  # best validation character error rate
     no_improvement_since = 0  # number of epochs no improvement of character error rate occurred
     # stop training after this number of epochs without improvement
     while True:
@@ -133,8 +138,8 @@ def infer(model: Model, fn_img: Path) -> None:
     print(f'Probability: {probability[0]}')
 
 
-def main():
-    """Main function."""
+def parse_args() -> argparse.Namespace:
+    """Parses arguments from the command line."""
     parser = argparse.ArgumentParser()
 
     parser.add_argument('--mode', choices=['train', 'validate', 'infer'], default='infer')
@@ -146,41 +151,48 @@ def main():
     parser.add_argument('--img_file', help='Image used for inference.', type=Path, default='../data/word.png')
     parser.add_argument('--early_stopping', help='Early stopping epochs.', type=int, default=25)
     parser.add_argument('--dump', help='Dump output of NN to CSV file(s).', action='store_true')
-    args = parser.parse_args()
 
-    # set chosen CTC decoder
+    return parser.parse_args()
+
+
+def main():
+    """Main function."""
+
+    # parse arguments and set CTC decoder
+    args = parse_args()
     decoder_mapping = {'bestpath': DecoderType.BestPath,
                        'beamsearch': DecoderType.BeamSearch,
                        'wordbeamsearch': DecoderType.WordBeamSearch}
     decoder_type = decoder_mapping[args.decoder]
 
-    # train or validate on IAM dataset
-    if args.mode in ['train', 'validate']:
-        # load training data, create TF model
+    # train the model
+    if args.mode == 'train':
         loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast)
-        char_list = loader.char_list
 
         # when in line mode, take care to have a whitespace in the char list
+        char_list = loader.char_list
         if args.line_mode and ' ' not in char_list:
             char_list = [' '] + char_list
 
-        # save characters of model for inference mode
-        open(FilePaths.fn_char_list, 'w').write(''.join(char_list))
+        # save characters of the model and words of the dataset
+        with open(FilePaths.fn_char_list, 'w') as f:
+            f.write(''.join(char_list))
+
+        with open(FilePaths.fn_corpus, 'w') as f:
+            f.write(' '.join(loader.train_words + loader.validation_words))
 
-        # save words contained in dataset into file
-        open(FilePaths.fn_corpus, 'w').write(' '.join(loader.train_words + loader.validation_words))
+        model = Model(char_list, decoder_type)
+        train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping)
 
-        # execute training or validation
-        if args.mode == 'train':
-            model = Model(char_list, decoder_type)
-            train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping)
-        elif args.mode == 'validate':
-            model = Model(char_list, decoder_type, must_restore=True)
-            validate(model, loader, args.line_mode)
+    # evaluate the model on the validation set
+    elif args.mode == 'validate':
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast)
+        model = Model(char_list_from_file(), decoder_type, must_restore=True)
+        validate(model, loader, args.line_mode)
 
     # infer text on test image
     elif args.mode == 'infer':
-        model = Model(list(open(FilePaths.fn_char_list).read()), decoder_type, must_restore=True, dump=args.dump)
+        model = Model(char_list_from_file(), decoder_type, must_restore=True, dump=args.dump)
         infer(model, args.img_file)
 
 
-- 
GitLab