From ca133566515c6e06a8fc3e7de05609a23983ae20 Mon Sep 17 00:00:00 2001 From: Harald Scheidl <harald@newpc.com> Date: Mon, 1 Feb 2021 17:56:15 +0100 Subject: [PATCH] geometric data augmentation --- src/SamplePreprocessor.py | 76 ++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 33 deletions(-) diff --git a/src/SamplePreprocessor.py b/src/SamplePreprocessor.py index 5547eec..1d5ffd9 100644 --- a/src/SamplePreprocessor.py +++ b/src/SamplePreprocessor.py @@ -9,53 +9,62 @@ def preprocess(img, imgSize, dataAugmentation=False): # there are damaged files in IAM dataset - just use black image instead if img is None: - img = np.zeros([imgSize[1], imgSize[0]]) + img = np.zeros(imgSize[::-1]) + # data augmentation img = img.astype(np.float) - - # increase dataset size by applying random stretches to the images if dataAugmentation: + # photometric data augmentation if random.random() < 0.25: rand_odd = lambda: random.randint(1, 3) * 2 + 1 img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0) if random.random() < 0.25: - img = cv2.dilate(img,np.ones((3,3))) + img = cv2.dilate(img, np.ones((3, 3))) if random.random() < 0.25: - img = cv2.erode(img,np.ones((3,3))) + img = cv2.erode(img, np.ones((3, 3))) if random.random() < 0.5: - img = img * (0.25 + random.random() * 0.75) + img = img * (0.5 + random.random() * 0.5) if random.random() < 0.25: - img = np.clip(img + (np.random.random(img.shape)-0.5) * random.randint(1, 50), 0, 255) + img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 50), 0, 255) if random.random() < 0.1: img = 255 - img - stretch = random.random() - 0.5 # -0.5 .. +0.5 - wStretched = max(int(img.shape[1] * (1 + stretch)), 1) # random width, but at least 1 - img = cv2.resize(img, (wStretched, img.shape[0])) # stretch horizontally by factor 0.5 .. 1.5 + # geometric data augmentation + wt, ht = imgSize + h, w = img.shape + f = min(wt / w, ht / h) + fx = f * np.random.uniform(0.75, 1.25) + fy = f * np.random.uniform(0.75, 1.25) - # create target image and copy sample image into it - (wt, ht) = imgSize - (h, w) = img.shape - fx = w / wt - fy = h / ht - f = max(fx, fy) - newSize = (max(min(wt, int(w / f)), 1), - max(min(ht, int(h / f)), 1)) # scale according to f (result at least 1 and at most wt or ht) - img = cv2.resize(img, newSize) - target = np.ones([ht, wt]) * 127.5 - - r_freedom = target.shape[0] - img.shape[0] - c_freedom = target.shape[1] - img.shape[1] - - if dataAugmentation: - r_off, c_off = random.randint(0, r_freedom), random.randint(0, c_freedom) + # random position around center + txc = (wt - w * fx) / 2 + tyc = (ht - h * fy) / 2 + freedom_x = wt // 10 + freedom_y = ht // 10 + tx = txc + np.random.randint(-freedom_x, freedom_x) + ty = tyc + np.random.randint(-freedom_y, freedom_y) + + # map image into target image + M = np.float32([[fx, 0, tx], [0, fy, ty]]) + target = np.ones(imgSize[::-1]) * 255 / 2 + img = cv2.warpAffine(img, M, dsize=imgSize, dst=target, borderMode=cv2.BORDER_TRANSPARENT) + + # no data augmentation else: - r_off, c_off = r_freedom // 2, c_freedom // 2 + # center image + wt, ht = imgSize + h, w = img.shape + f = min(wt / w, ht / h) + tx = (wt - w * f) / 2 + ty = (ht - h * f) / 2 - target[r_off:img.shape[0]+r_off, c_off:img.shape[1]+c_off] = img + # map image into target image + M = np.float32([[f, 0, tx], [0, f, ty]]) + target = np.ones(imgSize[::-1]) * 255 / 2 + img = cv2.warpAffine(img, M, dsize=imgSize, dst=target, borderMode=cv2.BORDER_TRANSPARENT) # transpose for TF - img = cv2.transpose(target) + img = cv2.transpose(img) # convert to range [-1, 1] img = img / 255 - 0.5 @@ -64,10 +73,11 @@ def preprocess(img, imgSize, dataAugmentation=False): if __name__ == '__main__': import matplotlib.pyplot as plt + img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE) - img_aug = preprocess(img, (128, 32), True) + img_aug = preprocess(img, (128, 32), False) plt.subplot(121) - plt.imshow(img) + plt.imshow(img, cmap='gray') plt.subplot(122) - plt.imshow(cv2.transpose(img_aug)) - plt.show() \ No newline at end of file + plt.imshow(cv2.transpose(img_aug), cmap='gray') + plt.show() -- GitLab