From 3507e59d8649f36ed503c85074cb0309996c1c14 Mon Sep 17 00:00:00 2001
From: Harald Scheidl <harald@newpc.com>
Date: Mon, 24 May 2021 18:13:56 +0200
Subject: [PATCH] reworked preprocessor

---
 src/dataloader_iam.py |  66 +-------------
 src/main.py           |  42 ++++++---
 src/model.py          |  13 +--
 src/preprocessor.py   | 208 +++++++++++++++++++++++++++---------------
 4 files changed, 169 insertions(+), 160 deletions(-)

diff --git a/src/dataloader_iam.py b/src/dataloader_iam.py
index 84d89e2..5ce2f17 100644
--- a/src/dataloader_iam.py
+++ b/src/dataloader_iam.py
@@ -7,16 +7,14 @@ import lmdb
 import numpy as np
 from path import Path
 
-from preprocessor import preprocess
-
 Sample = namedtuple('Sample', 'gt_text, file_path')
-Batch = namedtuple('Batch', 'gt_texts, imgs')
+Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')
 
 
 class DataLoaderIAM:
     "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
 
-    def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True, multi_word_mode=False):
+    def __init__(self, data_dir, batch_size, fast=True):
         """Loader for dataset."""
 
         assert data_dir.exists()
@@ -25,13 +23,9 @@ class DataLoaderIAM:
         if fast:
             self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)
 
-        self.max_text_len=max_text_len
-        self.multi_word_mode = multi_word_mode
-
         self.data_augmentation = False
         self.curr_idx = 0
         self.batch_size = batch_size
-        self.img_size = img_size
         self.samples = []
 
         f = open(data_dir / 'gt/words.txt')
@@ -74,28 +68,8 @@ class DataLoaderIAM:
         self.train_set()
 
         # list of all chars in dataset
-        if multi_word_mode:
-            chars.add(' ')
         self.char_list = sorted(list(chars))
 
-
-    @staticmethod
-    def _truncate_label(text, max_text_len):
-        """
-        Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
-        labels. Repeat letters cost double because of the blank symbol needing to be inserted.
-        If a too-long label is provided, ctc_loss returns an infinite gradient.
-        """
-        cost = 0
-        for i in range(len(text)):
-            if i != 0 and text[i] == text[i - 1]:
-                cost += 2
-            else:
-                cost += 1
-            if cost > max_text_len:
-                return text[:i]
-        return text
-
     def train_set(self):
         """Switch to randomly chosen subset of training set."""
         self.data_augmentation = True
@@ -138,33 +112,6 @@ class DataLoaderIAM:
 
         return img
 
-    @staticmethod
-    def _simulate_multi_words(imgs, gt_texts):
-        batch_size = len(imgs)
-
-        res_imgs = []
-        res_gt_texts = []
-
-        word_sep_space = 30
-
-        for i in range(batch_size):
-            j = (i + 1) % batch_size
-
-            img_left = imgs[i]
-            img_right = imgs[j]
-            h = max(img_left.shape[0], img_right.shape[0])
-            w = img_left.shape[1] + img_right.shape[1] + word_sep_space
-
-            target = np.ones([h, w], np.uint8) * 255
-
-            target[-img_left.shape[0]:, :img_left.shape[1]] = img_left
-            target[-img_right.shape[0]:, -img_right.shape[1]:] = img_right
-
-            res_imgs.append(target)
-            res_gt_texts.append(gt_texts[i] + ' ' + gt_texts[j])
-
-        return res_imgs, res_gt_texts
-
     def get_next(self):
         "Iterator."
         batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
@@ -172,12 +119,5 @@ class DataLoaderIAM:
         imgs = [self._get_img(i) for i in batch_range]
         gt_texts = [self.samples[i].gt_text for i in batch_range]
 
-        if self.multi_word_mode:
-            imgs, gt_texts = self._simulate_multi_words(imgs, gt_texts)
-
-        # apply data augmentation to images
-        imgs = [preprocess(img, self.img_size, data_augmentation=self.data_augmentation) for img in imgs]
-        gt_texts = [self._truncate_label(gt_text, self.max_text_len) for gt_text in gt_texts]
-
         self.curr_idx += self.batch_size
-        return Batch(gt_texts, imgs)
+        return Batch(imgs, gt_texts, self.batch_size)
diff --git a/src/main.py b/src/main.py
index 0f730e9..d54d87e 100644
--- a/src/main.py
+++ b/src/main.py
@@ -7,7 +7,7 @@ from path import Path
 
 from dataloader_iam import DataLoaderIAM, Batch
 from model import Model, DecoderType
-from preprocessor import preprocess
+from preprocessor import Preprocessor
 
 
 class FilePaths:
@@ -18,18 +18,22 @@ class FilePaths:
     fn_corpus = '../data/corpus.txt'
 
 
+train_img_size = (128, 32)
+
+
 def write_summary(char_error_rates, word_accuracies):
     with open(FilePaths.fn_summary, 'w') as f:
         json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
 
 
-def train(model, loader):
+def train(model, loader, line_mode):
     """Trains NN."""
     epoch = 0  # number of training epochs since start
     summary_char_error_rates = []
     summary_word_accuracies = []
+    preprocessor = Preprocessor(train_img_size, data_augmentation=True, line_mode=line_mode)
     best_char_error_rate = float('inf')  # best valdiation character error rate
-    no_improvement_since = 0  # number of epochs no improvement of character error rate occured
+    no_improvement_since = 0  # number of epochs no improvement of character error rate occurred
     early_stopping = 25  # stop training after this number of epochs without improvement
     while True:
         epoch += 1
@@ -41,11 +45,12 @@ def train(model, loader):
         while loader.has_next():
             iter_info = loader.get_iterator_info()
             batch = loader.get_next()
+            batch = preprocessor.process_batch(batch)
             loss = model.train_batch(batch)
             print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}')
 
         # validate
-        char_error_rate, word_accuracy = validate(model, loader)
+        char_error_rate, word_accuracy = validate(model, loader, line_mode)
 
         # write summary
         summary_char_error_rates.append(char_error_rate)
@@ -68,10 +73,11 @@ def train(model, loader):
             break
 
 
-def validate(model, loader):
+def validate(model, loader, line_mode):
     """Validates NN."""
     print('Validate NN')
     loader.validation_set()
+    preprocessor = Preprocessor(train_img_size, line_mode=line_mode)
     num_char_err = 0
     num_char_total = 0
     num_word_ok = 0
@@ -80,7 +86,8 @@ def validate(model, loader):
         iter_info = loader.get_iterator_info()
         print(f'Batch: {iter_info[0]} / {iter_info[1]}')
         batch = loader.get_next()
-        (recognized, _) = model.infer_batch(batch)
+        batch = preprocessor.process_batch(batch)
+        recognized, _ = model.infer_batch(batch)
 
         print('Ground truth -> Recognized')
         for i in range(len(recognized)):
@@ -101,8 +108,9 @@ def validate(model, loader):
 
 def infer(model, fn_img):
     """Recognizes text in image provided by file path."""
-    img = preprocess(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE), Model.img_size, dynamic_width=True)
-    batch = Batch(None, [img])
+    preprocessor = Preprocessor(train_img_size, dynamic_width=True)
+    img = preprocessor.process_img(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE))
+    batch = Batch([img], None)
     recognized, probability = model.infer_batch(batch, True)
     print(f'Recognized: "{recognized[0]}"')
     print(f'Probability: {probability[0]}')
@@ -119,6 +127,7 @@ def main():
     parser.add_argument('--data_dir', help='directory containing IAM dataset', type=Path, required=False)
     parser.add_argument('--fast', help='use lmdb to load images', action='store_true')
     parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
+    line_mode = True
     args = parser.parse_args()
 
     # set chosen CTC decoder
@@ -130,21 +139,24 @@ def main():
     # train or validate on IAM dataset
     if args.train or args.validate:
         # load training data, create TF model
-        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast, True)
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, fast=args.fast)
+        char_list = loader.char_list
+        if line_mode:
+            char_list = [' '] + char_list
 
         # save characters of model for inference mode
-        open(FilePaths.fn_char_list, 'w').write(str().join(loader.char_list))
+        open(FilePaths.fn_char_list, 'w').write(''.join(char_list))
 
         # save words contained in dataset into file
-        open(FilePaths.fn_corpus, 'w').write(str(' ').join(loader.train_words + loader.validation_words))
+        open(FilePaths.fn_corpus, 'w').write(' '.join(loader.train_words + loader.validation_words))
 
         # execute training or validation
         if args.train:
-            model = Model(loader.char_list, decoder_type)
-            train(model, loader)
+            model = Model(char_list, decoder_type)
+            train(model, loader, line_mode)
         elif args.validate:
-            model = Model(loader.char_list, decoder_type, must_restore=True)
-            validate(model, loader)
+            model = Model(char_list, decoder_type, must_restore=True)
+            validate(model, loader, line_mode)
 
     # infer text on test image
     else:
diff --git a/src/model.py b/src/model.py
index 3010644..d15afeb 100644
--- a/src/model.py
+++ b/src/model.py
@@ -18,10 +18,6 @@ class DecoderType:
 class Model:
     """Minimalistic TF model for HTR."""
 
-    # model constants
-    img_size = (128, 32)
-    max_text_len = 32
-
     def __init__(self, char_list, decoder_type=DecoderType.BestPath, must_restore=False, dump=False):
         """Init model: add CNN, RNN and CTC and initialize TF."""
         self.dump = dump
@@ -34,7 +30,7 @@ class Model:
         self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')
 
         # input image batch
-        self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, Model.img_size[1]))
+        self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None))
 
         # setup CNN, RNN and CTC
         self.setup_cnn()
@@ -57,7 +53,7 @@ class Model:
         # list of parameters for the layers
         kernel_vals = [5, 5, 3, 3, 3]
         feature_vals = [1, 32, 64, 128, 128, 256]
-        stride_vals = poolVals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
+        stride_vals = pool_vals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
         num_layers = len(stride_vals)
 
         # create layers
@@ -69,7 +65,7 @@ class Model:
             conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
             conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train)
             relu = tf.nn.relu(conv_norm)
-            pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1),
+            pool = tf.nn.max_pool2d(input=relu, ksize=(1, pool_vals[i][0], pool_vals[i][1], 1),
                                     strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID')
 
         self.cnn_out_4d = pool
@@ -214,10 +210,11 @@ class Model:
     def train_batch(self, batch):
         """Feed a batch into the NN to train it."""
         num_batch_elements = len(batch.imgs)
+        max_text_len = batch.imgs[0].shape[0] // 4
         sparse = self.to_sparse(batch.gt_texts)
         eval_list = [self.optimizer, self.loss]
         feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse,
-                     self.seq_len: [Model.max_text_len] * num_batch_elements, self.is_train: True}
+                     self.seq_len: [max_text_len] * num_batch_elements, self.is_train: True}
         _, loss_val = self.sess.run(eval_list, feed_dict)
         self.batches_trained += 1
         return loss_val
diff --git a/src/preprocessor.py b/src/preprocessor.py
index 183f67a..de9f426 100644
--- a/src/preprocessor.py
+++ b/src/preprocessor.py
@@ -3,90 +3,150 @@ import random
 import cv2
 import numpy as np
 
-# TODO: change to class
-# TODO: do multi-word simulation in here!
+from dataloader_iam import Batch
 
 
-def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
-    "put img into target img of size imgSize, transpose for TF and normalize gray-values"
-
-    # there are damaged files in IAM dataset - just use black image instead
-    if img is None:
-        img = np.zeros(img_size[::-1])
-
-    # data augmentation
-    img = img.astype(np.float)
-    if data_augmentation:
-        # photometric data augmentation
-        if random.random() < 0.25:
-            rand_odd = lambda: random.randint(1, 3) * 2 + 1
-            img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
-        if random.random() < 0.25:
-            img = cv2.dilate(img, np.ones((3, 3)))
-        if random.random() < 0.25:
-            img = cv2.erode(img, np.ones((3, 3)))
-        if random.random() < 0.5:
-            img = img * (0.25 + random.random() * 0.75)
-        if random.random() < 0.25:
-            img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 50), 0, 255)
-        if random.random() < 0.1:
-            img = 255 - img
-
-        # geometric data augmentation
-        wt, ht = img_size
-        h, w = img.shape
-        f = min(wt / w, ht / h)
-        fx = f * np.random.uniform(0.75, 1.1)
-        fy = f * np.random.uniform(0.75, 1.1)
-
-        # random position around center
-        txc = (wt - w * fx) / 2
-        tyc = (ht - h * fy) / 2
-        freedom_x = max((wt - fx * w) / 2, 0)
-        freedom_y = max((ht - fy * h) / 2, 0)
-        tx = txc + np.random.uniform(-freedom_x, freedom_x)
-        ty = tyc + np.random.uniform(-freedom_y, freedom_y)
-
-        # map image into target image
-        M = np.float32([[fx, 0, tx], [0, fy, ty]])
-        target = np.ones(img_size[::-1]) * 255 / 2
-        img = cv2.warpAffine(img, M, dsize=img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)
-
-    # no data augmentation
-    else:
-        if dynamic_width:
-            ht = img_size[1]
-            h, w = img.shape
-            f = ht / h
-            wt = int(f * w)
-            wt = wt + (4 - wt) % 4
-            tx = (wt - w * f) / 2
-            ty = 0
-        else:
-            wt, ht = img_size
+# TODO: change to class
+# TODO: do multi-word simulation in here!
+class Preprocessor:
+    def __init__(self, img_size, dynamic_width=False, data_augmentation=False, line_mode=False):
+        self.img_size = img_size
+        self.dynamic_width = dynamic_width
+        self.data_augmentation = data_augmentation
+        self.line_mode = line_mode
+
+    @staticmethod
+    def _truncate_label(text, max_text_len):
+        """
+        Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
+        labels. Repeat letters cost double because of the blank symbol needing to be inserted.
+        If a too-long label is provided, ctc_loss returns an infinite gradient.
+        """
+        cost = 0
+        for i in range(len(text)):
+            if i != 0 and text[i] == text[i - 1]:
+                cost += 2
+            else:
+                cost += 1
+            if cost > max_text_len:
+                return text[:i]
+        return text
+
+    @staticmethod
+    def _simulate_multi_words(batch):
+
+        res_imgs = []
+        res_gt_texts = []
+
+        word_sep_space = 30
+
+        for i in range(batch.batch_size):
+            j = (i + 1) % batch.batch_size
+
+            img_left = batch.imgs[i]
+            img_right = batch.imgs[j]
+            h = max(img_left.shape[0], img_right.shape[0])
+            w = img_left.shape[1] + img_right.shape[1] + word_sep_space
+
+            target = np.ones([h, w], np.uint8) * 255
+
+            target[-img_left.shape[0]:, :img_left.shape[1]] = img_left
+            target[-img_right.shape[0]:, -img_right.shape[1]:] = img_right
+
+            res_imgs.append(target)
+            res_gt_texts.append(batch.gt_texts[i] + ' ' + batch.gt_texts[j])
+
+        return Batch(res_imgs, res_gt_texts, batch.batch_size)
+
+    def process_img(self, img):
+        "put img into target img of size imgSize, transpose for TF and normalize gray-values"
+
+        # there are damaged files in IAM dataset - just use black image instead
+        if img is None:
+            img = np.zeros(self.img_size[::-1])
+
+        # data augmentation
+        img = img.astype(np.float)
+        if self.data_augmentation:
+            # photometric data augmentation
+            if random.random() < 0.25:
+                rand_odd = lambda: random.randint(1, 3) * 2 + 1
+                img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
+            if random.random() < 0.25:
+                img = cv2.dilate(img, np.ones((3, 3)))
+            if random.random() < 0.25:
+                img = cv2.erode(img, np.ones((3, 3)))
+            if random.random() < 0.5:
+                img = img * (0.25 + random.random() * 0.75)
+            if random.random() < 0.25:
+                img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 50), 0, 255)
+            if random.random() < 0.1:
+                img = 255 - img
+
+            # geometric data augmentation
+            wt, ht = self.img_size
             h, w = img.shape
             f = min(wt / w, ht / h)
-            tx = (wt - w * f) / 2
-            ty = (ht - h * f) / 2
-
-        # map image into target image
-        M = np.float32([[f, 0, tx], [0, f, ty]])
-        target = np.ones([ht, wt]) * 255 / 2
-        img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)
-
-    # transpose for TF
-    img = cv2.transpose(img)
-
-    # convert to range [-1, 1]
-    img = img / 255 - 0.5
-    return img
+            fx = f * np.random.uniform(0.75, 1.1)
+            fy = f * np.random.uniform(0.75, 1.1)
+
+            # random position around center
+            txc = (wt - w * fx) / 2
+            tyc = (ht - h * fy) / 2
+            freedom_x = max((wt - fx * w) / 2, 0)
+            freedom_y = max((ht - fy * h) / 2, 0)
+            tx = txc + np.random.uniform(-freedom_x, freedom_x)
+            ty = tyc + np.random.uniform(-freedom_y, freedom_y)
+
+            # map image into target image
+            M = np.float32([[fx, 0, tx], [0, fy, ty]])
+            target = np.ones(self.img_size[::-1]) * 255 / 2
+            img = cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)
+
+        # no data augmentation
+        else:
+            if self.dynamic_width:
+                ht = self.img_size[1]
+                h, w = img.shape
+                f = ht / h
+                wt = int(f * w)
+                wt = wt + (4 - wt) % 4
+                tx = (wt - w * f) / 2
+                ty = 0
+            else:
+                wt, ht = self.img_size
+                h, w = img.shape
+                f = min(wt / w, ht / h)
+                tx = (wt - w * f) / 2
+                ty = (ht - h * f) / 2
+
+            # map image into target image
+            M = np.float32([[f, 0, tx], [0, f, ty]])
+            target = np.ones([ht, wt]) * 255 / 2
+            img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)
+
+        # transpose for TF
+        img = cv2.transpose(img)
+
+        # convert to range [-1, 1]
+        img = img / 255 - 0.5
+        return img
+
+    def process_batch(self, batch):
+        if self.line_mode:
+            batch = self._simulate_multi_words(batch)
+
+        res_imgs = [self.process_img(img) for img in batch.imgs]
+        max_text_len = res_imgs[0].shape[0] // 4
+        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
+        return Batch(res_imgs, res_gt_texts, batch.batch_size)
 
 
 if __name__ == '__main__':
     import matplotlib.pyplot as plt
 
     img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
-    img_aug = preprocess(img, (128, 32), data_augmentation=False, dynamic_width=True)
+    img_aug = Preprocessor((128, 32), 32, dynamic_width=True, data_augmentation=False).process_img(img)
     plt.subplot(121)
     plt.imshow(img, cmap='gray')
     plt.subplot(122)
-- 
GitLab