From 6c76bca2e6f03d4c5f9e386c28c4b9eee7a9c6e4 Mon Sep 17 00:00:00 2001
From: Harald Scheidl <harald@newpc.com>
Date: Mon, 24 May 2021 14:25:42 +0200
Subject: [PATCH] started multi-word training

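Add an optional multi_word_mode to DataLoaderIAM: each word image in a
batch is concatenated with its successor (ground-truth texts joined by a
space) to simulate two-word samples. Labels are now truncated per batch
instead of at loading time, preprocess.py is renamed to preprocessor.py,
and the data augmentation is made slightly less aggressive.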
---
 src/dataloader_iam.py                  | 71 ++++++++++++++++++++------
 src/main.py                            |  4 +-
 src/model.py                           |  2 +-
 src/{preprocess.py => preprocessor.py} | 11 ++--
 4 files changed, 66 insertions(+), 22 deletions(-)
 rename src/{preprocess.py => preprocessor.py} (91%)

diff --git a/src/dataloader_iam.py b/src/dataloader_iam.py
index 3ff9d38..84d89e2 100644
--- a/src/dataloader_iam.py
+++ b/src/dataloader_iam.py
@@ -7,7 +7,7 @@ import lmdb
 import numpy as np
 from path import Path
 
-from preprocess import preprocess
+from preprocessor import preprocess
 
 Sample = namedtuple('Sample', 'gt_text, file_path')
 Batch = namedtuple('Batch', 'gt_texts, imgs')
@@ -16,7 +16,7 @@ Batch = namedtuple('Batch', 'gt_texts, imgs')
 class DataLoaderIAM:
     "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
 
-    def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True):
+    def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True, multi_word_mode=False):
         """Loader for dataset."""
 
         assert data_dir.exists()
@@ -25,6 +25,9 @@ class DataLoaderIAM:
         if fast:
             self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)
 
+        self.max_text_len = max_text_len
+        self.multi_word_mode = multi_word_mode
+
         self.data_augmentation = False
         self.curr_idx = 0
         self.batch_size = batch_size
@@ -52,7 +55,7 @@ class DataLoaderIAM:
                 continue
 
             # GT text are columns starting at 9
-            gt_text = self.truncate_label(' '.join(line_split[8:]), max_text_len)
+            gt_text = ' '.join(line_split[8:])
             chars = chars.union(set(list(gt_text)))
 
             # put sample into list
@@ -71,10 +74,13 @@ class DataLoaderIAM:
         self.train_set()
 
         # list of all chars in dataset
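+        # in multi-word mode the space character joins words, so it must be part of the alphabet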
+        if multi_word_mode:
+            chars.add(' ')
         self.char_list = sorted(list(chars))
 
+
     @staticmethod
-    def truncate_label(text, max_text_len):
+    def _truncate_label(text, max_text_len):
         """
         Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
         labels. Repeat letters cost double because of the blank symbol needing to be inserted.
@@ -121,22 +127,57 @@ class DataLoaderIAM:
         else:
             return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller
 
+    def _get_img(self, i):
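+        """Return the image of sample i, from the LMDB cache in fast mode or from disk otherwise."""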
+        if self.fast:
+            with self.env.begin() as txn:
+                basename = Path(self.samples[i].file_path).basename()
+                data = txn.get(basename.encode("ascii"))
+                img = pickle.loads(data)
+        else:
+            img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)
+
+        return img
+
+    @staticmethod
+    def _simulate_multi_words(imgs, gt_texts):
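+        """Build two-word samples: each word image is pasted next to its successor in the batch,
+        separated by a small white gap, and the ground-truth texts are joined by a space."""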
+        batch_size = len(imgs)
+
+        res_imgs = []
+        res_gt_texts = []
+
+        word_sep_space = 30
+
+        for i in range(batch_size):
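+            # pair each word image with the next one in the batch (wrapping around at the end)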
+            j = (i + 1) % batch_size
+
+            img_left = imgs[i]
+            img_right = imgs[j]
+            h = max(img_left.shape[0], img_right.shape[0])
+            w = img_left.shape[1] + img_right.shape[1] + word_sep_space
+
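+            # white canvas large enough to hold both word images plus the gap between them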
+            target = np.ones([h, w], np.uint8) * 255
+
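+            # paste the two word images bottom-aligned at the left and right ends of the canvas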
+            target[-img_left.shape[0]:, :img_left.shape[1]] = img_left
+            target[-img_right.shape[0]:, -img_right.shape[1]:] = img_right
+
+            res_imgs.append(target)
+            res_gt_texts.append(gt_texts[i] + ' ' + gt_texts[j])
+
+        return res_imgs, res_gt_texts
+
     def get_next(self):
-        "iterator"
+        "Iterator."
         batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
+
+        imgs = [self._get_img(i) for i in batch_range]
         gt_texts = [self.samples[i].gt_text for i in batch_range]
 
-        imgs = []
-        for i in batch_range:
-            if self.fast:
-                with self.env.begin() as txn:
-                    basename = Path(self.samples[i].file_path).basename()
-                    data = txn.get(basename.encode("ascii"))
-                    img = pickle.loads(data)
-            else:
-                img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)
+        if self.multi_word_mode:
+            imgs, gt_texts = self._simulate_multi_words(imgs, gt_texts)
 
-            imgs.append(preprocess(img, self.img_size, data_augmentation=self.data_augmentation))
+        # preprocess images (with optional data augmentation) and truncate over-long labels
+        imgs = [preprocess(img, self.img_size, data_augmentation=self.data_augmentation) for img in imgs]
+        gt_texts = [self._truncate_label(gt_text, self.max_text_len) for gt_text in gt_texts]
 
         self.curr_idx += self.batch_size
         return Batch(gt_texts, imgs)
diff --git a/src/main.py b/src/main.py
index 340a838..0f730e9 100644
--- a/src/main.py
+++ b/src/main.py
@@ -7,7 +7,7 @@ from path import Path
 
 from dataloader_iam import DataLoaderIAM, Batch
 from model import Model, DecoderType
-from preprocess import preprocess
+from preprocessor import preprocess
 
 
 class FilePaths:
@@ -130,7 +130,7 @@ def main():
     # train or validate on IAM dataset
     if args.train or args.validate:
         # load training data, create TF model
-        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast)
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast,
+                               multi_word_mode=True)
 
         # save characters of model for inference mode
         open(FilePaths.fn_char_list, 'w').write(str().join(loader.char_list))
diff --git a/src/model.py b/src/model.py
index 6c93873..3010644 100644
--- a/src/model.py
+++ b/src/model.py
@@ -259,7 +259,7 @@ class Model:
         if self.dump or calc_probability:
             eval_list.append(self.ctc_in_3d_tbc)
 
-        # sequence length
+        # sequence length depends on input image size (model downsizes width by 4)
         max_text_len = batch.imgs[0].shape[0] // 4
 
         # dict containing all tensor fed into the model
diff --git a/src/preprocess.py b/src/preprocessor.py
similarity index 91%
rename from src/preprocess.py
rename to src/preprocessor.py
index 3228024..183f67a 100644
--- a/src/preprocess.py
+++ b/src/preprocessor.py
@@ -3,6 +3,9 @@ import random
 import cv2
 import numpy as np
 
+# TODO: change to class
+# TODO: do multi-word simulation in here!
+
 
 def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
     "put img into target img of size imgSize, transpose for TF and normalize gray-values"
@@ -33,14 +36,14 @@ def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
         wt, ht = img_size
         h, w = img.shape
         f = min(wt / w, ht / h)
-        fx = f * np.random.uniform(0.75, 1.25)
-        fy = f * np.random.uniform(0.75, 1.25)
+        fx = f * np.random.uniform(0.75, 1.1)
+        fy = f * np.random.uniform(0.75, 1.1)
 
         # random position around center
         txc = (wt - w * fx) / 2
         tyc = (ht - h * fy) / 2
-        freedom_x = max((wt - fx * w) / 2, 0) + wt / 10
-        freedom_y = max((ht - fy * h) / 2, 0) + ht / 10
+        freedom_x = max((wt - fx * w) / 2, 0)
+        freedom_y = max((ht - fy * h) / 2, 0)
         tx = txc + np.random.uniform(-freedom_x, freedom_x)
         ty = tyc + np.random.uniform(-freedom_y, freedom_y)
 
-- 
GitLab