preprocessor.py
    import random
    from typing import Tuple
    
    import cv2
    import numpy as np
    
    from dataloader_iam import Batch
    
    
    class Preprocessor:
        def __init__(self,
                     img_size: Tuple[int, int],
                     padding: int = 0,
                     dynamic_width: bool = False,
                     data_augmentation: bool = False,
                     line_mode: bool = False) -> None:
            # dynamic width is only supported when data augmentation is disabled
            assert not (dynamic_width and data_augmentation)
            # padding is only supported when dynamic width is enabled
            assert not (padding > 0 and not dynamic_width)
    
            self.img_size = img_size
            self.padding = padding
            self.dynamic_width = dynamic_width
            self.data_augmentation = data_augmentation
            self.line_mode = line_mode
    
        @staticmethod
        def _truncate_label(text: str, max_text_len: int) -> str:
            """
            The CTC loss cannot be computed if there is no valid alignment between the text label and the
            input sequence. Repeated letters cost double because a blank symbol has to be inserted between
            them. If a label that is too long is provided, ctc_loss returns an infinite gradient, so
            over-long labels are truncated here.
            """
            cost = 0
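            # e.g. with max_text_len=3: "took" costs 1 ('t') + 1 ('o') + 2 (repeated 'o') = 4
            # after three characters and is truncated to "to", while "tok" costs 3 and is kept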
            for i in range(len(text)):
                if i != 0 and text[i] == text[i - 1]:
                    cost += 2
                else:
                    cost += 1
                if cost > max_text_len:
                    return text[:i]
            return text
    
        def _simulate_text_line(self, batch: Batch) -> Batch:
            """Create image of a text line by pasting multiple word images into an image."""
    
            default_word_sep = 30
            default_num_words = 5
    
            # go over all batch elements
            res_imgs = []
            res_gt_texts = []
            for i in range(batch.batch_size):
                # number of words to put into current line
                num_words = random.randint(1, 8) if self.data_augmentation else default_num_words
    
                # concat ground truth texts
                curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
                res_gt_texts.append(curr_gt)
    
                # put selected word images into list, compute target image size
                sel_imgs = []
                word_seps = [0]  # no separator before the first word
                h = 0
                w = 0
                for j in range(num_words):
                    curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                    curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                    h = max(h, curr_sel_img.shape[0])
                    w += curr_sel_img.shape[1]
                    sel_imgs.append(curr_sel_img)
                    if j + 1 < num_words:
                        w += curr_word_sep
                        word_seps.append(curr_word_sep)
    
                # put all selected word images into target image
                target = np.ones([h, w], np.uint8) * 255
                x = 0
                for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                    x += curr_word_sep
                    y = (h - curr_sel_img.shape[0]) // 2
                    target[y:y + curr_sel_img.shape[0], x:x + curr_sel_img.shape[1]] = curr_sel_img
                    x += curr_sel_img.shape[1]
    
                # put image of line into result
                res_imgs.append(target)
    
            return Batch(res_imgs, res_gt_texts, batch.batch_size)
    
        def process_img(self, img: np.ndarray) -> np.ndarray:
            """Resize to target size, apply data augmentation."""
    
            # there are damaged files in the IAM dataset - just use a black image instead
            if img is None:
                img = np.zeros(self.img_size[::-1])
    
            # data augmentation
            img = img.astype(float)  # work on a float copy for the augmentation and normalization steps
            if self.data_augmentation:
                # photometric data augmentation
                if random.random() < 0.25:
                    def rand_odd():
                        return random.randint(1, 3) * 2 + 1
                    img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
                if random.random() < 0.25:
                    img = cv2.dilate(img, np.ones((3, 3)))
                if random.random() < 0.25:
                    img = cv2.erode(img, np.ones((3, 3)))
    
                # geometric data augmentation
                wt, ht = self.img_size
                h, w = img.shape
                f = min(wt / w, ht / h)
                fx = f * np.random.uniform(0.75, 1.05)
                fy = f * np.random.uniform(0.75, 1.05)
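                # e.g. for a 100x400 (h x w) word image and img_size (256, 32):
                # f = min(256 / 400, 32 / 100) = 0.32; fx and fy then jitter the scale
                # independently between 0.75 * f and 1.05 * f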
    
                # random position around center
                txc = (wt - w * fx) / 2
                tyc = (ht - h * fy) / 2
                freedom_x = max((wt - fx * w) / 2, 0)
                freedom_y = max((ht - fy * h) / 2, 0)
                tx = txc + np.random.uniform(-freedom_x, freedom_x)
                ty = tyc + np.random.uniform(-freedom_y, freedom_y)
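                # continuing the example: if fx = fy = 0.32, the scaled image is 128x32,
                # so freedom_x = 64 and freedom_y = 0: the image is shifted horizontally by
                # up to 64 px around the center and stays vertically centered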
    
                # map image into target image
                M = np.float32([[fx, 0, tx], [0, fy, ty]])
                target = np.ones(self.img_size[::-1]) * 255
                img = cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)
    
                # photometric data augmentation
                if random.random() < 0.5:
                    img = img * (0.25 + random.random() * 0.75)
                if random.random() < 0.25:
                    img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
                if random.random() < 0.1:
                    img = 255 - img
    
            # no data augmentation
            else:
                if self.dynamic_width:
                    ht = self.img_size[1]
                    h, w = img.shape
                    f = ht / h
                    wt = int(f * w + self.padding)
                    wt = wt + (4 - wt) % 4  # round the width up to the next multiple of 4
                    tx = (wt - w * f) / 2
                    ty = 0
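                    # e.g. a 64x250 (h x w) image with target height 32 and padding = 8:
                    # f = 0.5, wt = int(125 + 8) = 133, rounded up to 136, tx = (136 - 125) / 2 = 5.5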
                else:
                    wt, ht = self.img_size
                    h, w = img.shape
                    f = min(wt / w, ht / h)
                    tx = (wt - w * f) / 2
                    ty = (ht - h * f) / 2
    
                # map image into target image
                M = np.float32([[f, 0, tx], [0, f, ty]])
                target = np.ones([ht, wt]) * 255
                img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)
    
            # transpose for TF
            img = cv2.transpose(img)
    
            # normalize pixel values from [0, 255] to [-0.5, 0.5]
            img = img / 255 - 0.5
            return img
    
        def process_batch(self, batch: Batch) -> Batch:
            if self.line_mode:
                batch = self._simulate_text_line(batch)
    
            res_imgs = [self.process_img(img) for img in batch.imgs]
            # a processed image has shape (width, height) after the transpose in process_img;
            # width // 4 bounds the label length that the CTC loss can still align
            max_text_len = res_imgs[0].shape[0] // 4
            res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
            return Batch(res_imgs, res_gt_texts, batch.batch_size)
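

    def demo_dynamic_width():
        """Standalone sketch, not called by main(): preprocess a single word image with
        dynamic width enabled. The target size (128, 32) and the padding of 16 px are
        illustrative values; the test image path is the same one used in main()."""
        img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
        preprocessor = Preprocessor((128, 32), padding=16, dynamic_width=True)
        processed = preprocessor.process_img(img)
        # after the transpose for TF the shape is (width, height) with height 32 and a
        # width that is a multiple of 4; pixel values lie in [-0.5, 0.5]
        print(processed.shape, processed.min(), processed.max())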
    
    
    def main():
        import matplotlib.pyplot as plt
    
        img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
        img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)
        plt.subplot(121)
        plt.imshow(img, cmap='gray')
        plt.subplot(122)
        plt.imshow(cv2.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
        plt.show()
    
    
    if __name__ == '__main__':
        main()