diff --git a/src/dataloader_iam.py b/src/dataloader_iam.py index 654708a1dbc382349cbd478265133cc9e127ca45..3ff9d38abb4c6339b3fdebb5aaf79c4e15ef5569 100644 --- a/src/dataloader_iam.py +++ b/src/dataloader_iam.py @@ -136,7 +136,7 @@ class DataLoaderIAM: else: img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE) - imgs.append(preprocess(img, self.img_size, self.data_augmentation)) + imgs.append(preprocess(img, self.img_size, data_augmentation=self.data_augmentation)) self.curr_idx += self.batch_size return Batch(gt_texts, imgs) diff --git a/src/main.py b/src/main.py index 492cb80eebc59c95f1d0ab6d437ce1a119032fdc..340a8380af4c87383cdad519a2964c39f26cc070 100644 --- a/src/main.py +++ b/src/main.py @@ -41,7 +41,7 @@ def train(model, loader): while loader.has_next(): iter_info = loader.get_iterator_info() batch = loader.get_next() - loss = model.trainBatch(batch) + loss = model.train_batch(batch) print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}') # validate @@ -101,9 +101,9 @@ def validate(model, loader): def infer(model, fn_img): """Recognizes text in image provided by file path.""" - img = preprocess(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE), Model.img_size) + img = preprocess(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE), Model.img_size, dynamic_width=True) batch = Batch(None, [img]) - (recognized, probability) = model.infer_batch(batch, True) + recognized, probability = model.infer_batch(batch, True) print(f'Recognized: "{recognized[0]}"') print(f'Probability: {probability[0]}') diff --git a/src/model.py b/src/model.py index 63652c7a74ded766e9d0e1ce30fa6642117edc77..6c938737fc12224d212991fab79f87e99fc8f0db 100644 --- a/src/model.py +++ b/src/model.py @@ -34,7 +34,7 @@ class Model: self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train') # input image batch - self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, Model.img_size[0], Model.img_size[1])) + self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, Model.img_size[1])) # setup CNN, RNN and CTC self.setup_cnn() @@ -117,7 +117,7 @@ class Model: # calc loss for each element to compute label probability self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32, - shape=[Model.max_text_len, None, len(self.char_list) + 1]) + shape=[None, None, len(self.char_list) + 1]) self.loss_per_element = tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input, sequence_length=self.seq_len, ctc_merge_repeated=True) @@ -211,7 +211,7 @@ class Model: # map labels to chars for all batch elements return [str().join([self.char_list[c] for c in labelStr]) for labelStr in label_strs] - def trainBatch(self, batch): + def train_batch(self, batch): """Feed a batch into the NN to train it.""" num_batch_elements = len(batch.imgs) sparse = self.to_sparse(batch.gt_texts) @@ -259,8 +259,11 @@ class Model: if self.dump or calc_probability: eval_list.append(self.ctc_in_3d_tbc) + # sequence length + max_text_len = batch.imgs[0].shape[0] // 4 + # dict containing all tensor fed into the model - feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [Model.max_text_len] * num_batch_elements, + feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False} # evaluate model @@ -283,7 +286,7 @@ class Model: ctc_input = eval_res[1] eval_list = self.loss_per_element feed_dict = {self.saved_ctc_input: ctc_input, self.gt_texts: sparse, - self.seq_len: [Model.max_text_len] * num_batch_elements, self.is_train: False} + self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False} loss_vals = self.sess.run(eval_list, feed_dict) probs = np.exp(-loss_vals) diff --git a/src/preprocess.py b/src/preprocess.py index 272713682bb74743e06eee5907bedc8070314a60..322802425bd2e3e1472809d7ee065891bf4ffc5a 100644 --- a/src/preprocess.py +++ b/src/preprocess.py @@ -4,7 +4,7 @@ import cv2 import numpy as np -def preprocess(img, img_size, data_augmentation=False): +def preprocess(img, img_size, dynamic_width=False, data_augmentation=False): "put img into target img of size imgSize, transpose for TF and normalize gray-values" # there are damaged files in IAM dataset - just use black image instead @@ -51,17 +51,25 @@ def preprocess(img, img_size, data_augmentation=False): # no data augmentation else: - # center image - wt, ht = img_size - h, w = img.shape - f = min(wt / w, ht / h) - tx = (wt - w * f) / 2 - ty = (ht - h * f) / 2 + if dynamic_width: + ht = img_size[1] + h, w = img.shape + f = ht / h + wt = int(f * w) + wt = wt + (4 - wt) % 4 + tx = (wt - w * f) / 2 + ty = 0 + else: + wt, ht = img_size + h, w = img.shape + f = min(wt / w, ht / h) + tx = (wt - w * f) / 2 + ty = (ht - h * f) / 2 # map image into target image M = np.float32([[f, 0, tx], [0, f, ty]]) - target = np.ones(img_size[::-1]) * 255 / 2 - img = cv2.warpAffine(img, M, dsize=img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT) + target = np.ones([ht, wt]) * 255 / 2 + img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT) # transpose for TF img = cv2.transpose(img) @@ -75,7 +83,7 @@ if __name__ == '__main__': import matplotlib.pyplot as plt img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE) - img_aug = preprocess(img, (128, 32), True) + img_aug = preprocess(img, (128, 32), data_augmentation=False, dynamic_width=True) plt.subplot(121) plt.imshow(img, cmap='gray') plt.subplot(122)