Commit 6c76bca2 authored by Harald Scheidl
started multi word training

parent 7edd1f86
dataloader_iam.py

@@ -7,7 +7,7 @@
 import lmdb
 import numpy as np
 from path import Path
-from preprocess import preprocess
+from preprocessor import preprocess
 
 Sample = namedtuple('Sample', 'gt_text, file_path')
 Batch = namedtuple('Batch', 'gt_texts, imgs')
@@ -16,7 +16,7 @@
 class DataLoaderIAM:
     "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
 
-    def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True):
+    def __init__(self, data_dir, batch_size, img_size, max_text_len, fast=True, multi_word_mode=False):
         """Loader for dataset."""
 
         assert data_dir.exists()
@@ -25,6 +25,9 @@
         if fast:
             self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)
 
+        self.max_text_len = max_text_len
+        self.multi_word_mode = multi_word_mode
+
         self.data_augmentation = False
         self.curr_idx = 0
         self.batch_size = batch_size
@@ -52,7 +55,7 @@
                 continue
 
             # GT text are columns starting at 9
-            gt_text = self.truncate_label(' '.join(line_split[8:]), max_text_len)
+            gt_text = ' '.join(line_split[8:])
             chars = chars.union(set(list(gt_text)))
 
             # put sample into list
...@@ -71,10 +74,13 @@ class DataLoaderIAM: ...@@ -71,10 +74,13 @@ class DataLoaderIAM:
self.train_set() self.train_set()
# list of all chars in dataset # list of all chars in dataset
if multi_word_mode:
chars.add(' ')
self.char_list = sorted(list(chars)) self.char_list = sorted(list(chars))
@staticmethod @staticmethod
def truncate_label(text, max_text_len): def _truncate_label(text, max_text_len):
""" """
Function ctc_loss can't compute loss if it cannot find a mapping between text label and input Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
labels. Repeat letters cost double because of the blank symbol needing to be inserted. labels. Repeat letters cost double because of the blank symbol needing to be inserted.
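
The body of _truncate_label is untouched by this commit (only the method is renamed), so it does not appear in the diff. As a rough sketch of the rule the docstring describes, assuming nothing beyond that docstring (name and example values are hypothetical):

def truncate_label_sketch(text, max_text_len):
    # Each character costs 1; a repeated character costs 2, because CTC has
    # to insert a blank symbol between repeats. Cut the label off as soon as
    # the cost budget max_text_len is exceeded.
    cost = 0
    for i in range(len(text)):
        cost += 2 if i != 0 and text[i] == text[i - 1] else 1
        if cost > max_text_len:
            return text[:i]
    return text

print(truncate_label_sketch('bell', 4))  # 'bel': the repeated 'l' would cost 2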
@@ -121,13 +127,7 @@
         else:
             return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller
 
-    def get_next(self):
-        "iterator"
-        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
-        gt_texts = [self.samples[i].gt_text for i in batch_range]
-
-        imgs = []
-        for i in batch_range:
+    def _get_img(self, i):
         if self.fast:
             with self.env.begin() as txn:
                 basename = Path(self.samples[i].file_path).basename()
@@ -136,7 +136,48 @@
         else:
             img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)
-            imgs.append(preprocess(img, self.img_size, data_augmentation=self.data_augmentation))
+        return img
+
+    @staticmethod
+    def _simulate_multi_words(imgs, gt_texts):
+        batch_size = len(imgs)
+
+        res_imgs = []
+        res_gt_texts = []
+
+        word_sep_space = 30
+
+        for i in range(batch_size):
+            j = (i + 1) % batch_size
+
+            img_left = imgs[i]
+            img_right = imgs[j]
+            h = max(img_left.shape[0], img_right.shape[0])
+            w = img_left.shape[1] + img_right.shape[1] + word_sep_space
+
+            target = np.ones([h, w], np.uint8) * 255
+
+            target[-img_left.shape[0]:, :img_left.shape[1]] = img_left
+            target[-img_right.shape[0]:, -img_right.shape[1]:] = img_right
+
+            res_imgs.append(target)
+            res_gt_texts.append(gt_texts[i] + ' ' + gt_texts[j])
+
+        return res_imgs, res_gt_texts
+
+    def get_next(self):
+        "Iterator."
+        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
+
+        imgs = [self._get_img(i) for i in batch_range]
+        gt_texts = [self.samples[i].gt_text for i in batch_range]
+
+        if self.multi_word_mode:
+            imgs, gt_texts = self._simulate_multi_words(imgs, gt_texts)
+
+        # apply data augmentation to images
+        imgs = [preprocess(img, self.img_size, data_augmentation=self.data_augmentation) for img in imgs]
+        gt_texts = [self._truncate_label(gt_text, self.max_text_len) for gt_text in gt_texts]
+
         self.curr_idx += self.batch_size
         return Batch(gt_texts, imgs)
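
To make the stitching concrete: each image in the batch is paired with its neighbour ((i + 1) % batch_size, wrapping around), so the batch size is unchanged and every word appears twice, once as the left and once as the right half of a pair. A minimal standalone sketch, with dummy white images standing in for real IAM word crops:

import numpy as np

word_a = np.full((40, 100), 255, np.uint8)   # dummy 'word' image, 40x100
word_b = np.full((60, 150), 255, np.uint8)   # dummy 'word' image, 60x150

sep = 30                                     # word_sep_space from the hunk above
h = max(word_a.shape[0], word_b.shape[0])    # canvas height: the taller word
w = word_a.shape[1] + word_b.shape[1] + sep  # both widths plus the 30 px gap

canvas = np.ones([h, w], np.uint8) * 255     # white background
canvas[-word_a.shape[0]:, :word_a.shape[1]] = word_a    # left word, bottom-aligned
canvas[-word_b.shape[0]:, -word_b.shape[1]:] = word_b   # right word, bottom-aligned

print(canvas.shape)  # (60, 280)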
main.py

@@ -7,7 +7,7 @@
 from path import Path
 
 from dataloader_iam import DataLoaderIAM, Batch
 from model import Model, DecoderType
-from preprocess import preprocess
+from preprocessor import preprocess
 
 
 class FilePaths:
@@ -130,7 +130,7 @@
     # train or validate on IAM dataset
     if args.train or args.validate:
         # load training data, create TF model
-        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast)
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.img_size, Model.max_text_len, args.fast, True)
 
         # save characters of model for inference mode
         open(FilePaths.fn_char_list, 'w').write(str().join(loader.char_list))
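
Note that multi-word mode is simply hard-wired on here (the trailing True). A minimal sketch of the setup this call produces; the concrete batch_size, img_size, and max_text_len values below are illustrative assumptions, not taken from this diff:

from path import Path
from dataloader_iam import DataLoaderIAM

loader = DataLoaderIAM(Path('../data'),       # assumed IAM data directory
                       batch_size=100,        # assumption
                       img_size=(128, 32),    # assumption for Model.img_size
                       max_text_len=32,       # assumption for Model.max_text_len
                       fast=True,             # read images from the LMDB store
                       multi_word_mode=True)  # the new flag, passed positionally above

while loader.has_next():       # iteration API as used elsewhere in the repo
    batch = loader.get_next()  # Batch(gt_texts, imgs) with two-word samples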
model.py

@@ -259,7 +259,7 @@
         if self.dump or calc_probability:
             eval_list.append(self.ctc_in_3d_tbc)
 
-        # sequence length
+        # sequence length depends on input image size (model downsizes width by 4)
         max_text_len = batch.imgs[0].shape[0] // 4
 
         # dict containing all tensors fed into the model
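
Why shape[0] gives the sequence length: preprocess transposes each image for TF, so axis 0 holds the image width, and per the updated comment the model downsizes that width by a factor of 4. A quick worked check (the 128x32 input size is an assumption, not stated in this diff):

import numpy as np

img = np.zeros((128, 32))   # preprocessed image: axis 0 is width after transpose

# The CNN shrinks the width by 4, so the CTC input has width // 4 time steps;
# labels longer than that cannot be aligned and would break ctc_loss.
seq_len = img.shape[0] // 4
print(seq_len)              # 128 // 4 == 32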
preprocessor.py

@@ -3,6 +3,9 @@
 import random
 
 import cv2
 import numpy as np
 
 
+# TODO: change to class
+# TODO: do multi-word simulation in here!
 def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
     "put img into target img of size img_size, transpose for TF and normalize gray-values"
...@@ -33,14 +36,14 @@ def preprocess(img, img_size, dynamic_width=False, data_augmentation=False): ...@@ -33,14 +36,14 @@ def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
wt, ht = img_size wt, ht = img_size
h, w = img.shape h, w = img.shape
f = min(wt / w, ht / h) f = min(wt / w, ht / h)
fx = f * np.random.uniform(0.75, 1.25) fx = f * np.random.uniform(0.75, 1.1)
fy = f * np.random.uniform(0.75, 1.25) fy = f * np.random.uniform(0.75, 1.1)
# random position around center # random position around center
txc = (wt - w * fx) / 2 txc = (wt - w * fx) / 2
tyc = (ht - h * fy) / 2 tyc = (ht - h * fy) / 2
freedom_x = max((wt - fx * w) / 2, 0) + wt / 10 freedom_x = max((wt - fx * w) / 2, 0)
freedom_y = max((ht - fy * h) / 2, 0) + ht / 10 freedom_y = max((ht - fy * h) / 2, 0)
tx = txc + np.random.uniform(-freedom_x, freedom_x) tx = txc + np.random.uniform(-freedom_x, freedom_x)
ty = tyc + np.random.uniform(-freedom_y, freedom_y) ty = tyc + np.random.uniform(-freedom_y, freedom_y)
......
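
This hunk only shows how the scale factors (fx, fy) and the translation (tx, ty) are sampled; applying them happens outside the hunk. A sketch of how such values are typically applied, assuming an affine warp onto a white canvas (the helper name and the warp call are illustrative, not taken from this diff):

import cv2
import numpy as np

def apply_random_transform(img, wt, ht, fx, fy, tx, ty):
    # 2x3 affine matrix: scale by (fx, fy), then shift by (tx, ty).
    M = np.float32([[fx, 0, tx], [0, fy, ty]])
    # White target canvas of the model's input size (ht rows, wt columns).
    target = np.ones((ht, wt), np.uint8) * 255
    return cv2.warpAffine(img, M, (wt, ht), dst=target,
                          borderMode=cv2.BORDER_TRANSPARENT)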