Commit 6c76bca2 authored by Harald Scheidl

started multi word training

parent 7edd1f86
dataloader_iam.py
@@ -7,7 +7,7 @@ import lmdb
 import numpy as np
 from path import Path
 
-from preprocess import preprocess
+from preprocessor import preprocess
 
 Sample = namedtuple('Sample', 'gt_text, file_path')
 Batch = namedtuple('Batch', 'gt_texts, imgs')
@@ -16,7 +16,7 @@ Batch = namedtuple('Batch', 'gt_texts, imgs')
 class DataLoaderIAM:
     "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
 
-    def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True):
+    def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True, multi_word_mode=False):
         """Loader for dataset."""
 
         assert data_dir.exists()
@@ -25,6 +25,9 @@ class DataLoaderIAM:
         if fast:
             self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)
 
+        self.max_text_len = max_text_len
+        self.multi_word_mode = multi_word_mode
+
         self.data_augmentation = False
         self.curr_idx = 0
         self.batch_size = batch_size
@@ -52,7 +55,7 @@ class DataLoaderIAM:
                 continue
 
             # GT text are columns starting at 9
-            gt_text = self.truncate_label(' '.join(line_split[8:]), max_text_len)
+            gt_text = ' '.join(line_split[8:])
             chars = chars.union(set(list(gt_text)))
 
             # put sample into list
@@ -71,10 +74,13 @@ class DataLoaderIAM:
         self.train_set()
 
         # list of all chars in dataset
+        if multi_word_mode:
+            chars.add(' ')
+
         self.char_list = sorted(list(chars))
 
     @staticmethod
-    def truncate_label(text, max_text_len):
+    def _truncate_label(text, max_text_len):
         """
         Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
         labels. Repeat letters cost double because of the blank symbol needing to be inserted.
@@ -121,13 +127,7 @@ class DataLoaderIAM:
         else:
             return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller
 
-    def get_next(self):
-        "iterator"
-        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
-        gt_texts = [self.samples[i].gt_text for i in batch_range]
-
-        imgs = []
-        for i in batch_range:
+    def _get_img(self, i):
         if self.fast:
             with self.env.begin() as txn:
                 basename = Path(self.samples[i].file_path).basename()
@@ -136,7 +136,48 @@ class DataLoaderIAM:
         else:
             img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)
 
-            imgs.append(preprocess(img, self.img_size, data_augmentation=self.data_augmentation))
+        return img
+
+    @staticmethod
+    def _simulate_multi_words(imgs, gt_texts):
+        batch_size = len(imgs)
+
+        res_imgs = []
+        res_gt_texts = []
+
+        word_sep_space = 30
+
+        for i in range(batch_size):
+            j = (i + 1) % batch_size
+
+            img_left = imgs[i]
+            img_right = imgs[j]
+            h = max(img_left.shape[0], img_right.shape[0])
+            w = img_left.shape[1] + img_right.shape[1] + word_sep_space
+
+            target = np.ones([h, w], np.uint8) * 255
+
+            target[-img_left.shape[0]:, :img_left.shape[1]] = img_left
+            target[-img_right.shape[0]:, -img_right.shape[1]:] = img_right
+
+            res_imgs.append(target)
+            res_gt_texts.append(gt_texts[i] + ' ' + gt_texts[j])
+
+        return res_imgs, res_gt_texts
+
+    def get_next(self):
+        "Iterator."
+        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
+
+        imgs = [self._get_img(i) for i in batch_range]
+        gt_texts = [self.samples[i].gt_text for i in batch_range]
+
+        if self.multi_word_mode:
+            imgs, gt_texts = self._simulate_multi_words(imgs, gt_texts)
+
+        # apply data augmentation to images
+        imgs = [preprocess(img, self.img_size, data_augmentation=self.data_augmentation) for img in imgs]
+        gt_texts = [self._truncate_label(gt_text, self.max_text_len) for gt_text in gt_texts]
 
         self.curr_idx += self.batch_size
         return Batch(gt_texts, imgs)
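
For reference, here is a minimal standalone sketch (not part of the commit) of what the new _simulate_multi_words helper does: each word image in the batch is paired with its right neighbour, wrapping around at the end of the batch, both are pasted bottom-aligned onto a white canvas with a fixed 30-pixel gap, and the two ground-truth texts are joined with a space. The standalone function name and the dummy images and texts below are made up for illustration.

import numpy as np

def simulate_multi_words(imgs, gt_texts, word_sep_space=30):
    # pair each word image with its right neighbour (cyclically) on a white canvas
    res_imgs, res_gt_texts = [], []
    for i in range(len(imgs)):
        j = (i + 1) % len(imgs)
        img_left, img_right = imgs[i], imgs[j]
        h = max(img_left.shape[0], img_right.shape[0])
        w = img_left.shape[1] + img_right.shape[1] + word_sep_space
        target = np.ones([h, w], np.uint8) * 255  # white background
        target[-img_left.shape[0]:, :img_left.shape[1]] = img_left  # bottom-left corner
        target[-img_right.shape[0]:, -img_right.shape[1]:] = img_right  # bottom-right corner
        res_imgs.append(target)
        res_gt_texts.append(gt_texts[i] + ' ' + gt_texts[j])
    return res_imgs, res_gt_texts

# two dummy "word" images of different sizes
imgs = [np.zeros((20, 50), np.uint8), np.zeros((30, 40), np.uint8)]
out_imgs, out_texts = simulate_multi_words(imgs, ['hello', 'world'])
print(out_imgs[0].shape)  # (30, 120): max height, sum of widths plus the 30px gap
print(out_texts)          # ['hello world', 'world hello']

The batch size stays the same; every image just appears twice, once as the left word and once as the right word. Because the combined labels now contain spaces, the loader also needs ' ' in its character list, which is exactly what the chars.add(' ') change above takes care of.
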
main.py
@@ -7,7 +7,7 @@ from path import Path
 
 from dataloader_iam import DataLoaderIAM, Batch
 from model import Model, DecoderType
-from preprocess import preprocess
+from preprocessor import preprocess
 
 
 class FilePaths:
@@ -130,7 +130,7 @@ def main():
     # train or validate on IAM dataset
     if args.train or args.validate:
         # load training data, create TF model
-        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast)
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast, True)
 
         # save characters of model for inference mode
         open(FilePaths.fn_char_list, 'w').write(str().join(loader.char_list))
model.py
@@ -259,7 +259,7 @@ class Model:
         if self.dump or calc_probability:
             eval_list.append(self.ctc_in_3d_tbc)
 
-        # sequence length
+        # sequence length depends on input image size (model downsizes width by 4)
         max_text_len = batch.imgs[0].shape[0] // 4
 
         # dict containing all tensor fed into the model
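
A quick illustration (not part of the commit) of why this comment matters: preprocess() transposes images for TensorFlow, so batch.imgs[0].shape[0] is the image width, and with a 128x32 input (an assumed default here) the model produces 128 // 4 = 32 CTC time steps. A label only fits if its character count, plus one inserted blank per repeated-character pair, stays within that budget; this is the same rule described in the loader's _truncate_label docstring.

def label_fits(text, seq_len):
    # CTC needs one time step per character plus one blank between equal neighbours
    cost = len(text) + sum(1 for a, b in zip(text, text[1:]) if a == b)
    return cost <= seq_len

seq_len = 128 // 4  # input width downsized by 4 -> 32 time steps
print(label_fits('bassoon', seq_len))  # True: 7 chars + 2 repeated pairs = 9 <= 32
print(label_fits('a' * 20, seq_len))   # False: 20 chars + 19 repeated pairs = 39 > 32
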
preprocess.py → preprocessor.py
@@ -3,6 +3,9 @@ import random
 import cv2
 import numpy as np
 
+# TODO: change to class
+# TODO: do multi-word simulation in here!
+
 
 def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
     "put img into target img of size imgSize, transpose for TF and normalize gray-values"
@@ -33,14 +36,14 @@ def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
         wt, ht = img_size
         h, w = img.shape
         f = min(wt / w, ht / h)
-        fx = f * np.random.uniform(0.75, 1.25)
-        fy = f * np.random.uniform(0.75, 1.25)
+        fx = f * np.random.uniform(0.75, 1.1)
+        fy = f * np.random.uniform(0.75, 1.1)
 
         # random position around center
         txc = (wt - w * fx) / 2
         tyc = (ht - h * fy) / 2
-        freedom_x = max((wt - fx * w) / 2, 0) + wt / 10
-        freedom_y = max((ht - fy * h) / 2, 0) + ht / 10
+        freedom_x = max((wt - fx * w) / 2, 0)
+        freedom_y = max((ht - fy * h) / 2, 0)
         tx = txc + np.random.uniform(-freedom_x, freedom_x)
         ty = tyc + np.random.uniform(-freedom_y, freedom_y)
 
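
For context, a hedged sketch of how these parameters are typically combined into an affine warp; the cv2.warpAffine call and the helper name random_place are assumptions based on the surrounding preprocess() code, which the diff does not show. The change lowers the upper scale factor from 1.25 to 1.1 and drops the wt / 10 and ht / 10 translation slack, so augmented words are enlarged less aggressively and are no longer deliberately shifted partly off the canvas.

import cv2
import numpy as np

def random_place(img, wt=128, ht=32):
    # randomly scale and position a grayscale word image on a wt x ht canvas
    h, w = img.shape
    f = min(wt / w, ht / h)
    fx = f * np.random.uniform(0.75, 1.1)  # new, tighter scale range
    fy = f * np.random.uniform(0.75, 1.1)
    txc = (wt - w * fx) / 2  # translation that would center the word
    tyc = (ht - h * fy) / 2
    freedom_x = max((wt - fx * w) / 2, 0)  # the old version added wt / 10 slack here,
    freedom_y = max((ht - fy * h) / 2, 0)  # which could push the word partly off the canvas
    tx = txc + np.random.uniform(-freedom_x, freedom_x)
    ty = tyc + np.random.uniform(-freedom_y, freedom_y)
    M = np.float32([[fx, 0, tx], [0, fy, ty]])
    return cv2.warpAffine(img, M, (wt, ht), borderValue=255)

print(random_place(np.zeros((20, 50), np.uint8)).shape)  # (32, 128)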