diff --git a/.gitignore b/.gitignore
index 1b9575d156b3cb7320b4dfc73039b72535140fbf..8efcba50d500e0da4d36918183250ec11358fc3c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,5 @@
-data/words/
-data/words.txt
-data/words.7z
+data/words*
+data/words.txt*
 src/__pycache__/
 model/checkpoint
 model/snapshot-*
diff --git a/src/DataLoader.py b/src/DataLoaderIAM.py
similarity index 74%
rename from src/DataLoader.py
rename to src/DataLoaderIAM.py
index 002d5711525da4d5d5ee3931b6c4d9d305ef6733..7b4910f0323abf5f6ea7194d4bccc9ede85db674 100644
--- a/src/DataLoader.py
+++ b/src/DataLoaderIAM.py
@@ -1,138 +1,148 @@
-from __future__ import division
-from __future__ import print_function
-
-import os
-import random
-
-import cv2
-import numpy as np
-
-from SamplePreprocessor import preprocess
-
-
-class Sample:
-    "sample from the dataset"
-
-    def __init__(self, gtText, filePath):
-        self.gtText = gtText
-        self.filePath = filePath
-
-
-class Batch:
-    "batch containing images and ground truth texts"
-
-    def __init__(self, gtTexts, imgs):
-        self.imgs = np.stack(imgs, axis=0)
-        self.gtTexts = gtTexts
-
-
-class DataLoader:
-    "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
-
-    def __init__(self, filePath, batchSize, imgSize, maxTextLen):
-        "loader for dataset at given location, preprocess images and text according to parameters"
-
-        assert filePath[-1] == '/'
-
-        self.dataAugmentation = False
-        self.currIdx = 0
-        self.batchSize = batchSize
-        self.imgSize = imgSize
-        self.samples = []
-
-        f = open(filePath + 'words.txt')
-        chars = set()
-        bad_samples = []
-        bad_samples_reference = ['a01-117-05-02.png', 'r06-022-03-05.png']
-        for line in f:
-            # ignore comment line
-            if not line or line[0] == '#':
-                continue
-
-            lineSplit = line.strip().split(' ')
-            assert len(lineSplit) >= 9
-
-            # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
-            fileNameSplit = lineSplit[0].split('-')
-            fileName = filePath + 'words/' + fileNameSplit[0] + '/' + fileNameSplit[0] + '-' + fileNameSplit[1] + '/' + \
-                       lineSplit[0] + '.png'
-
-            # GT text are columns starting at 9
-            gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen)
-            chars = chars.union(set(list(gtText)))
-
-            # check if image is not empty
-            if not os.path.getsize(fileName):
-                bad_samples.append(lineSplit[0] + '.png')
-                continue
-
-            # put sample into list
-            self.samples.append(Sample(gtText, fileName))
-
-        # some images in the IAM dataset are known to be damaged, don't show warning for them
-        if set(bad_samples) != set(bad_samples_reference):
-            print("Warning, damaged images found:", bad_samples)
-            print("Damaged images expected:", bad_samples_reference)
-
-        # split into training and validation set: 95% - 5%
-        splitIdx = int(0.95 * len(self.samples))
-        self.trainSamples = self.samples[:splitIdx]
-        self.validationSamples = self.samples[splitIdx:]
-
-        # put words into lists
-        self.trainWords = [x.gtText for x in self.trainSamples]
-        self.validationWords = [x.gtText for x in self.validationSamples]
-
-        # number of randomly chosen samples per epoch for training
-        self.numTrainSamplesPerEpoch = 25000
-
-        # start with train set
-        self.trainSet()
-
-        # list of all chars in dataset
-        self.charList = sorted(list(chars))
-
-    def truncateLabel(self, text, maxTextLen):
-        # ctc_loss can't compute loss if it cannot find a mapping between text label and input
-        # labels. Repeat letters cost double because of the blank symbol needing to be inserted.
-        # If a too-long label is provided, ctc_loss returns an infinite gradient
-        cost = 0
-        for i in range(len(text)):
-            if i != 0 and text[i] == text[i - 1]:
-                cost += 2
-            else:
-                cost += 1
-            if cost > maxTextLen:
-                return text[:i]
-        return text
-
-    def trainSet(self):
-        "switch to randomly chosen subset of training set"
-        self.dataAugmentation = True
-        self.currIdx = 0
-        random.shuffle(self.trainSamples)
-        self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]
-
-    def validationSet(self):
-        "switch to validation set"
-        self.dataAugmentation = False
-        self.currIdx = 0
-        self.samples = self.validationSamples
-
-    def getIteratorInfo(self):
-        "current batch index and overall number of batches"
-        return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)
-
-    def hasNext(self):
-        "iterator"
-        return self.currIdx + self.batchSize <= len(self.samples)
-
-    def getNext(self):
-        "iterator"
-        batchRange = range(self.currIdx, self.currIdx + self.batchSize)
-        gtTexts = [self.samples[i].gtText for i in batchRange]
-        imgs = [
-            preprocess(cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE), self.imgSize, self.dataAugmentation)
-            for i in batchRange]
-        self.currIdx += self.batchSize
-        return Batch(gtTexts, imgs)
+import pickle
+import random
+
+import cv2
+import lmdb
+import numpy as np
+from path import Path
+
+from SamplePreprocessor import preprocess
+
+
+class Sample:
+    "sample from the dataset"
+
+    def __init__(self, gtText, filePath):
+        self.gtText = gtText
+        self.filePath = filePath
+
+
+class Batch:
+    "batch containing images and ground truth texts"
+
+    def __init__(self, gtTexts, imgs):
+        self.imgs = np.stack(imgs, axis=0)
+        self.gtTexts = gtTexts
+
+
+class DataLoaderIAM:
+    "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
+
+    def __init__(self, data_dir, batchSize, imgSize, maxTextLen, fast=True):
+        "loader for dataset at given location, preprocess images and text according to parameters"
+
+        assert data_dir.exists()
+
+        self.fast = fast
+        if fast:
+            self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)
+
+        self.dataAugmentation = False
+        self.currIdx = 0
+        self.batchSize = batchSize
+        self.imgSize = imgSize
+        self.samples = []
+
+        f = open(data_dir / 'gt/words.txt')
+        chars = set()
+        bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset
+        for line in f:
+            # ignore comment line
+            if not line or line[0] == '#':
+                continue
+
+            lineSplit = line.strip().split(' ')
+            assert len(lineSplit) >= 9
+
+            # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
+            fileNameSplit = lineSplit[0].split('-')
+            fileName = data_dir / 'img' / fileNameSplit[0] / f'{fileNameSplit[0]}-{fileNameSplit[1]}' / lineSplit[0] + '.png'
+
+            if lineSplit[0] in bad_samples_reference:
+                print('Ignoring known broken image:', fileName)
+                continue
+
+            # GT text are columns starting at 9
+            gtText = self.truncateLabel(' '.join(lineSplit[8:]), maxTextLen)
+            chars = chars.union(set(list(gtText)))
+
+            # put sample into list
+            self.samples.append(Sample(gtText, fileName))
+
+        # split into training and validation set: 95% - 5%
+        splitIdx = int(0.95 * len(self.samples))
+        self.trainSamples = self.samples[:splitIdx]
+        self.validationSamples = self.samples[splitIdx:]
+
+        # put words into lists
+        self.trainWords = [x.gtText for x in self.trainSamples]
+        self.validationWords = [x.gtText for x in self.validationSamples]
+
+        # number of randomly chosen samples per epoch for training
+        self.numTrainSamplesPerEpoch = 25000
+
+        # start with train set
+        self.trainSet()
+
+        # list of all chars in dataset
+        self.charList = sorted(list(chars))
+
+    def truncateLabel(self, text, maxTextLen):
+        # ctc_loss can't compute loss if it cannot find a mapping between text label and input
+        # labels. Repeat letters cost double because of the blank symbol needing to be inserted.
+        # If a too-long label is provided, ctc_loss returns an infinite gradient
+        cost = 0
+        for i in range(len(text)):
+            if i != 0 and text[i] == text[i - 1]:
+                cost += 2
+            else:
+                cost += 1
+            if cost > maxTextLen:
+                return text[:i]
+        return text
+
+    def trainSet(self):
+        "switch to randomly chosen subset of training set"
+        self.dataAugmentation = True
+        self.currIdx = 0
+        random.shuffle(self.trainSamples)
+        self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]
+
+    def validationSet(self):
+        "switch to validation set"
+        self.dataAugmentation = False
+        self.currIdx = 0
+        self.samples = self.validationSamples
+
+    def getIteratorInfo(self):
+        "current batch index and overall number of batches"
+        return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)
+
+    def hasNext(self):
+        "iterator"
+        return self.currIdx + self.batchSize <= len(self.samples)
+
+    def getNext(self):
+        "iterator"
+        batchRange = range(self.currIdx, self.currIdx + self.batchSize)
+        gtTexts = [self.samples[i].gtText for i in batchRange]
+
+        imgs = []
+        for i in batchRange:
+            if self.fast:
+                with self.env.begin() as txn:
+                    basename = Path(self.samples[i].filePath).basename()
+                    data = txn.get(basename.encode("ascii"))
+                    img = pickle.loads(data)
+            else:
+                img = cv2.imread(self.samples[i].filePath, cv2.IMREAD_GRAYSCALE)
+
+            imgs.append(preprocess(img, self.imgSize, self.dataAugmentation))
+
+        self.currIdx += self.batchSize
+        return Batch(gtTexts, imgs)
+
+
+if __name__ == '__main__':
+    dl = DataLoaderIAM('../data/', 50, (128, 32), 32)
+    dl.getNext()
diff --git a/src/Model.py b/src/Model.py
index 07e300fd880eb93ba401eb05b792d764ae8bd2fc..1b58629e6c6c4703d599395548ef719ae8f39524 100644
--- a/src/Model.py
+++ b/src/Model.py
@@ -1,6 +1,3 @@
-from __future__ import division
-from __future__ import print_function
-
 import numpy as np
 import os
 import sys
@@ -20,7 +17,6 @@ class Model:
     "minimalistic TF model for HTR"
 
     # model constants
-    batchSize = 50
     imgSize = (128, 32)
     maxTextLen = 32
 
@@ -45,10 +41,9 @@ class Model:
 
         # setup optimizer to train NN
         self.batchesTrained = 0
-        self.learningRate = tf.compat.v1.placeholder(tf.float32, shape=[])
         self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
         with tf.control_dependencies(self.update_ops):
-            self.optimizer = tf.compat.v1.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)
+            self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss)
 
         # initialize TF
         (self.sess, self.saver) = self.setupTF()
@@ -221,12 +216,9 @@ class Model:
         "feed a batch into the NN to train it"
         numBatchElements = len(batch.imgs)
         sparse = self.toSparse(batch.gtTexts)
-        rate = 0.01 if self.batchesTrained < 10 else (
-            0.001 if self.batchesTrained < 10000 else 0.0001)  # decay learning rate
         evalList = [self.optimizer, self.loss]
-        feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse,
-                    self.seqLen: [Model.maxTextLen] * numBatchElements, self.learningRate: rate, self.is_train: True}
-        (_, lossVal) = self.sess.run(evalList, feedDict)
+        feedDict = {self.inputImgs: batch.imgs, self.gtTexts: sparse, self.seqLen: [Model.maxTextLen] * numBatchElements, self.is_train: True}
+        _, lossVal = self.sess.run(evalList, feedDict)
         self.batchesTrained += 1
         return lossVal
 
diff --git a/src/SamplePreprocessor.py b/src/SamplePreprocessor.py
index 1d571e9f33a858c58021ef681996b19c2c306424..5547eecf4ddc6523f70179a29f06d608a9f53ab6 100644
--- a/src/SamplePreprocessor.py
+++ b/src/SamplePreprocessor.py
@@ -1,6 +1,3 @@
-from __future__ import division
-from __future__ import print_function
-
 import random
 
 import cv2
@@ -14,9 +11,25 @@ def preprocess(img, imgSize, dataAugmentation=False):
     if img is None:
         img = np.zeros([imgSize[1], imgSize[0]])
 
+    img = img.astype(np.float)
+
     # increase dataset size by applying random stretches to the images
     if dataAugmentation:
-        stretch = (random.random() - 0.5)  # -0.5 .. +0.5
+        if random.random() < 0.25:
+            rand_odd = lambda: random.randint(1, 3) * 2 + 1
+            img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
+        if random.random() < 0.25:
+            img = cv2.dilate(img,np.ones((3,3)))
+        if random.random() < 0.25:
+            img = cv2.erode(img,np.ones((3,3)))
+        if random.random() < 0.5:
+            img = img * (0.25 + random.random() * 0.75)
+        if random.random() < 0.25:
+            img = np.clip(img + (np.random.random(img.shape)-0.5) * random.randint(1, 50), 0, 255)
+        if random.random() < 0.1:
+            img = 255 - img
+
+        stretch = random.random() - 0.5  # -0.5 .. +0.5
         wStretched = max(int(img.shape[1] * (1 + stretch)), 1)  # random width, but at least 1
         img = cv2.resize(img, (wStretched, img.shape[0]))  # stretch horizontally by factor 0.5 .. 1.5
 
@@ -29,16 +42,32 @@ def preprocess(img, imgSize, dataAugmentation=False):
     newSize = (max(min(wt, int(w / f)), 1),
                max(min(ht, int(h / f)), 1))  # scale according to f (result at least 1 and at most wt or ht)
     img = cv2.resize(img, newSize)
-    target = np.ones([ht, wt]) * 255
-    target[0:newSize[1], 0:newSize[0]] = img
+    target = np.ones([ht, wt]) * 127.5
+    
+    r_freedom = target.shape[0] - img.shape[0]
+    c_freedom = target.shape[1] - img.shape[1]
+    
+    if dataAugmentation:
+        r_off, c_off = random.randint(0, r_freedom), random.randint(0, c_freedom)
+    else:
+        r_off, c_off = r_freedom // 2, c_freedom // 2
+
+    target[r_off:img.shape[0]+r_off, c_off:img.shape[1]+c_off] = img
 
     # transpose for TF
     img = cv2.transpose(target)
 
-    # normalize
-    (m, s) = cv2.meanStdDev(img)
-    m = m[0][0]
-    s = s[0][0]
-    img = img - m
-    img = img / s if s > 0 else img
+    # convert to range [-1, 1]
+    img = img / 255 - 0.5
     return img
+
+
+if __name__ == '__main__':
+    import matplotlib.pyplot as plt
+    img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
+    img_aug = preprocess(img, (128, 32), True)
+    plt.subplot(121)
+    plt.imshow(img)
+    plt.subplot(122)
+    plt.imshow(cv2.transpose(img_aug))
+    plt.show()
\ No newline at end of file
diff --git a/src/analyze.py b/src/analyze.py
index d34a97b93cdf1944289fb94453f16ea27ef82e45..6658fd5f0bb99286d2e445d2c3e69e71c3667947 100644
--- a/src/analyze.py
+++ b/src/analyze.py
@@ -1,6 +1,3 @@
-from __future__ import division
-from __future__ import print_function
-
 import copy
 import math
 import pickle
diff --git a/src/create_lmdb.py b/src/create_lmdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..979571b103bbe2f27b2200409bb965ddf853d8b0
--- /dev/null
+++ b/src/create_lmdb.py
@@ -0,0 +1,27 @@
+import argparse
+import pickle
+
+import cv2
+import lmdb
+from path import Path
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--data_dir', type=Path, required=True)
+args = parser.parse_args()
+
+# 2GB is enough for IAM dataset
+assert not (args.data_dir / 'lmdb').exists()
+env = lmdb.open(str(args.data_dir / 'lmdb'), map_size=1024 * 1024 * 1024 * 2)
+
+# go over all png files
+fn_imgs = list((args.data_dir / 'img').walkfiles('*.png'))
+
+# and put the imgs into lmdb as pickled grayscale imgs
+with env.begin(write=True) as txn:
+    for i, fn_img in enumerate(fn_imgs):
+        print(i, len(fn_imgs))
+        img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
+        basename = fn_img.basename()
+        txn.put(basename.encode("ascii"), pickle.dumps(img))
+
+env.close()
diff --git a/src/main.py b/src/main.py
index 50f54e06f717cbe68b53139da4dbfad652a65108..9e62aa2473aa027f1c5b474e76f44c57491b0e26 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,21 +1,18 @@
-from __future__ import division
-from __future__ import print_function
-
 import argparse
 
 import cv2
 import editdistance
 
-from DataLoader import DataLoader, Batch
+from DataLoaderIAM import DataLoaderIAM, Batch
 from Model import Model, DecoderType
 from SamplePreprocessor import preprocess
+from path import Path
 
 
 class FilePaths:
     "filenames and paths to data"
     fnCharList = '../model/charList.txt'
     fnAccuracy = '../model/accuracy.txt'
-    fnTrain = '../data/'
     fnInfer = '../data/test.png'
     fnCorpus = '../data/corpus.txt'
 
@@ -25,7 +22,7 @@ def train(model, loader):
     epoch = 0  # number of training epochs since start
     bestCharErrorRate = float('inf')  # best valdiation character error rate
     noImprovementSince = 0  # number of epochs no improvement of character error rate occured
-    earlyStopping = 5  # stop training after this number of epochs without improvement
+    earlyStopping = 25  # stop training after this number of epochs without improvement
     while True:
         epoch += 1
         print('Epoch:', epoch)
@@ -37,7 +34,7 @@ def train(model, loader):
             iterInfo = loader.getIteratorInfo()
             batch = loader.getNext()
             loss = model.trainBatch(batch)
-            print('Batch:', iterInfo[0], '/', iterInfo[1], 'Loss:', loss)
+            print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}')
 
         # validate
         charErrorRate = validate(model, loader)
@@ -49,14 +46,14 @@ def train(model, loader):
             noImprovementSince = 0
             model.save()
             open(FilePaths.fnAccuracy, 'w').write(
-                'Validation character error rate of saved model: %f%%' % (charErrorRate * 100.0))
+                f'Validation character error rate of saved model: {charErrorRate * 100.0}%')
         else:
-            print('Character error rate not improved')
+            print(f'Character error rate not improved, best so far: {charErrorRate * 100.0}%')
             noImprovementSince += 1
 
         # stop training if no more improvement in the last x epochs
         if noImprovementSince >= earlyStopping:
-            print('No more improvement since %d epochs. Training stopped.' % earlyStopping)
+            print(f'No more improvement since {earlyStopping} epochs. Training stopped.')
             break
 
 
@@ -70,7 +67,7 @@ def validate(model, loader):
     numWordTotal = 0
     while loader.hasNext():
         iterInfo = loader.getIteratorInfo()
-        print('Batch:', iterInfo[0], '/', iterInfo[1])
+        print(f'Batch: {iterInfo[0]} / {iterInfo[1]}')
         batch = loader.getNext()
         (recognized, _) = model.inferBatch(batch)
 
@@ -87,7 +84,7 @@ def validate(model, loader):
     # print validation result
     charErrorRate = numCharErr / numCharTotal
     wordAccuracy = numWordOK / numWordTotal
-    print('Character error rate: %f%%. Word accuracy: %f%%.' % (charErrorRate * 100.0, wordAccuracy * 100.0))
+    print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.')
     return charErrorRate
 
 
@@ -96,8 +93,8 @@ def infer(model, fnImg):
     img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize)
     batch = Batch(None, [img])
     (recognized, probability) = model.inferBatch(batch, True)
-    print('Recognized:', '"' + recognized[0] + '"')
-    print('Probability:', probability[0])
+    print(f'Recognized: "{recognized[0]}"')
+    print(f'Probability: {probability[0]}')
 
 
 def main():
@@ -110,6 +107,9 @@ def main():
     parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding',
                         action='store_true')
     parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
+    parser.add_argument('--fast', help='use lmdb to load images', action='store_true')
+    parser.add_argument('--data_dir', help='directory containing IAM dataset', type=Path, required=False)
+    parser.add_argument('--batch_size', help='batch size', type=int, default=100)
 
     args = parser.parse_args()
 
@@ -122,7 +122,7 @@ def main():
     # train or validate on IAM dataset
     if args.train or args.validate:
         # load training data, create TF model
-        loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen)
+        loader = DataLoaderIAM(args.data_dir, args.batch_size, Model.imgSize, Model.maxTextLen, args.fast)
 
         # save characters of model for inference mode
         open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))