From 6c76bca2e6f03d4c5f9e386c28c4b9eee7a9c6e4 Mon Sep 17 00:00:00 2001 From: Harald Scheidl <harald@newpc.com> Date: Mon, 24 May 2021 14:25:42 +0200 Subject: [PATCH] started multi word training --- src/dataloader_iam.py | 71 ++++++++++++++++++++------ src/main.py | 4 +- src/model.py | 2 +- src/{preprocess.py => preprocessor.py} | 11 ++-- 4 files changed, 66 insertions(+), 22 deletions(-) rename src/{preprocess.py => preprocessor.py} (91%) diff --git a/src/dataloader_iam.py b/src/dataloader_iam.py index 3ff9d38..84d89e2 100644 --- a/src/dataloader_iam.py +++ b/src/dataloader_iam.py @@ -7,7 +7,7 @@ import lmdb import numpy as np from path import Path -from preprocess import preprocess +from preprocessor import preprocess Sample = namedtuple('Sample', 'gt_text, file_path') Batch = namedtuple('Batch', 'gt_texts, imgs') @@ -16,7 +16,7 @@ Batch = namedtuple('Batch', 'gt_texts, imgs') class DataLoaderIAM: "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database" - def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True): + def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True, multi_word_mode=False): """Loader for dataset.""" assert data_dir.exists() @@ -25,6 +25,9 @@ class DataLoaderIAM: if fast: self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True) + self.max_text_len=max_text_len + self.multi_word_mode = multi_word_mode + self.data_augmentation = False self.curr_idx = 0 self.batch_size = batch_size @@ -52,7 +55,7 @@ class DataLoaderIAM: continue # GT text are columns starting at 9 - gt_text = self.truncate_label(' '.join(line_split[8:]), max_text_len) + gt_text = ' '.join(line_split[8:]) chars = chars.union(set(list(gt_text))) # put sample into list @@ -71,10 +74,13 @@ class DataLoaderIAM: self.train_set() # list of all chars in dataset + if multi_word_mode: + chars.add(' ') self.char_list = sorted(list(chars)) + @staticmethod - def truncate_label(text, max_text_len): + def _truncate_label(text, max_text_len): """ Function ctc_loss can't compute loss if it cannot find a mapping between text label and input labels. Repeat letters cost double because of the blank symbol needing to be inserted. @@ -121,22 +127,57 @@ class DataLoaderIAM: else: return self.curr_idx < len(self.samples) # val set: allow last batch to be smaller + def _get_img(self, i): + if self.fast: + with self.env.begin() as txn: + basename = Path(self.samples[i].file_path).basename() + data = txn.get(basename.encode("ascii")) + img = pickle.loads(data) + else: + img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE) + + return img + + @staticmethod + def _simulate_multi_words(imgs, gt_texts): + batch_size = len(imgs) + + res_imgs = [] + res_gt_texts = [] + + word_sep_space = 30 + + for i in range(batch_size): + j = (i + 1) % batch_size + + img_left = imgs[i] + img_right = imgs[j] + h = max(img_left.shape[0], img_right.shape[0]) + w = img_left.shape[1] + img_right.shape[1] + word_sep_space + + target = np.ones([h, w], np.uint8) * 255 + + target[-img_left.shape[0]:, :img_left.shape[1]] = img_left + target[-img_right.shape[0]:, -img_right.shape[1]:] = img_right + + res_imgs.append(target) + res_gt_texts.append(gt_texts[i] + ' ' + gt_texts[j]) + + return res_imgs, res_gt_texts + def get_next(self): - "iterator" + "Iterator." batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples))) + + imgs = [self._get_img(i) for i in batch_range] gt_texts = [self.samples[i].gt_text for i in batch_range] - imgs = [] - for i in batch_range: - if self.fast: - with self.env.begin() as txn: - basename = Path(self.samples[i].file_path).basename() - data = txn.get(basename.encode("ascii")) - img = pickle.loads(data) - else: - img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE) + if self.multi_word_mode: + imgs, gt_texts = self._simulate_multi_words(imgs, gt_texts) - imgs.append(preprocess(img, self.img_size, data_augmentation=self.data_augmentation)) + # apply data augmentation to images + imgs = [preprocess(img, self.img_size, data_augmentation=self.data_augmentation) for img in imgs] + gt_texts = [self._truncate_label(gt_text, self.max_text_len) for gt_text in gt_texts] self.curr_idx += self.batch_size return Batch(gt_texts, imgs) diff --git a/src/main.py b/src/main.py index 340a838..0f730e9 100644 --- a/src/main.py +++ b/src/main.py @@ -7,7 +7,7 @@ from path import Path from dataloader_iam import DataLoaderIAM, Batch from model import Model, DecoderType -from preprocess import preprocess +from preprocessor import preprocess class FilePaths: @@ -130,7 +130,7 @@ def main(): # train or validate on IAM dataset if args.train or args.validate: # load training data, create TF model - loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast) + loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast, True) # save characters of model for inference mode open(FilePaths.fn_char_list, 'w').write(str().join(loader.char_list)) diff --git a/src/model.py b/src/model.py index 6c93873..3010644 100644 --- a/src/model.py +++ b/src/model.py @@ -259,7 +259,7 @@ class Model: if self.dump or calc_probability: eval_list.append(self.ctc_in_3d_tbc) - # sequence length + # sequence length depends on input image size (model downsizes width by 4) max_text_len = batch.imgs[0].shape[0] // 4 # dict containing all tensor fed into the model diff --git a/src/preprocess.py b/src/preprocessor.py similarity index 91% rename from src/preprocess.py rename to src/preprocessor.py index 3228024..183f67a 100644 --- a/src/preprocess.py +++ b/src/preprocessor.py @@ -3,6 +3,9 @@ import random import cv2 import numpy as np +# TODO: change to class +# TODO: do multi-word simulation in here! + def preprocess(img, img_size, dynamic_width=False, data_augmentation=False): "put img into target img of size imgSize, transpose for TF and normalize gray-values" @@ -33,14 +36,14 @@ def preprocess(img, img_size, dynamic_width=False, data_augmentation=False): wt, ht = img_size h, w = img.shape f = min(wt / w, ht / h) - fx = f * np.random.uniform(0.75, 1.25) - fy = f * np.random.uniform(0.75, 1.25) + fx = f * np.random.uniform(0.75, 1.1) + fy = f * np.random.uniform(0.75, 1.1) # random position around center txc = (wt - w * fx) / 2 tyc = (ht - h * fy) / 2 - freedom_x = max((wt - fx * w) / 2, 0) + wt / 10 - freedom_y = max((ht - fy * h) / 2, 0) + ht / 10 + freedom_x = max((wt - fx * w) / 2, 0) + freedom_y = max((ht - fy * h) / 2, 0) tx = txc + np.random.uniform(-freedom_x, freedom_x) ty = tyc + np.random.uniform(-freedom_y, freedom_y) -- GitLab