diff --git a/README.md b/README.md
index ee490a931e221c34a7f5a84e8ce2ab8a1408a3e6..2f01ea0b13e5c5d1b733c1b458dd314df03d3215 100644
--- a/README.md
+++ b/README.md
@@ -14,10 +14,10 @@ The model takes **images of single words or text lines (multiple words) as input
 ## Run demo
 
-* Download one of the trained models
-  * [Model trained on word images](https://www.dropbox.com/s/lod3gabgtuj0zzn/model.zip?dl=1):
-    only handle single words per image, but gives better results for IAM word dataset
-  * [Model trained on text line images](TODO):
+* Download one of the pretrained models
+  * [Model trained on word images](https://www.dropbox.com/s/mya8hw6jyzqm0a3/word-model.zip?dl=1):
+    only handles single words per image, but gives better results on the IAM word dataset
+  * [Model trained on text line images](https://www.dropbox.com/s/7xwkcilho10rthn/line-model.zip?dl=1):
     can handle multiple words in one image
 * Put the contents of the downloaded zip-file into the `model` directory of the repository
 * Go to the `src` directory
@@ -30,28 +30,27 @@ The input images, and the expected outputs are shown below when the text line mo
 
 ```
 > python main.py
-Init with stored values from ../model/snapshot-15
+Init with stored values from ../model/snapshot-13
 Recognized: "word"
-Probability: 0.9741360545158386
+Probability: 0.9806370139122009
 ```
 
 ```
 > python main.py --img_file ../data/line.png
-Init with stored values from ../model/snapshot-15
+Init with stored values from ../model/snapshot-13
 Recognized: "or work on line level"
-Probability: 0.8010453581809998
+Probability: 0.6674373149871826
 ```
-
 ## Command line arguments
 * `--mode`: select between "train", "validate" and "infer". Defaults to "infer".
 * `--decoder`: select from CTC decoders "bestpath", "beamsearch" and "wordbeamsearch". Defaults to "bestpath". For option "wordbeamsearch" see details below.
 * `--batch_size`: batch size.
 * `--data_dir`: directory containing IAM dataset (with subdirectories `img` and `gt`).
-* `--fast`: use LMDB to load images (faster than loading image files from disk).
-* `--line_mode`': train reading text lines instead of single words
+* `--fast`: use LMDB to load images faster.
+* `--line_mode`: train reading text lines instead of single words.
+* `--early_stopping`: stop training after this number of epochs without improvement (default: 25).
 * `--img_file`: image that is used for inference.
 * `--dump`: dumps the output of the NN to CSV file(s) saved in the `dump` folder. Can be used as input for the [CTCDecoder](https://github.com/githubharald/CTCDecoder).
@@ -75,8 +74,9 @@ Further, the manually created list of word-characters can be found in the file `
 Beam width is set to 50 to conform with the beam width of vanilla beam search decoding.
 
-## Train model with IAM dataset
+## Train model on IAM dataset
 
+### Prepare dataset
 Follow these instructions to get the IAM dataset:
 
 * Register for free at this [website](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database)
@@ -86,7 +86,7 @@ Follow these instructions to get the IAM dataset:
 * Put `words.txt` into the `gt` directory
 * Put the content (directories `a01`, `a02`, ...)
   of `words.tgz` into the `img` directory
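+
+The resulting folder layout should look roughly like this (a sketch; `a01`, `a02`, ... are the directories contained in `words.tgz`):
+```
+path/to/iam
+├── gt
+│   └── words.txt
+└── img
+    ├── a01
+    ├── a02
+    └── ...
+```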
 
-### Start the training
+### Run training
 * Delete files from `model` directory if you want to train from scratch
 * Go to the `src` directory and execute `python main.py --mode train --data_dir path/to/IAM`
@@ -95,32 +95,35 @@ Follow these instructions to get the IAM dataset:
 the model is trained on text line images created by combining multiple word images into one
 * Training stops after a fixed number of epochs without improvement
 
+The pretrained word model was trained with this command on a GTX 1050 Ti:
+```
+python main.py --mode train --fast --data_dir path/to/iam --batch_size 500 --early_stopping 15
+```
+
+And the line model with:
+```
+python main.py --mode train --fast --data_dir path/to/iam --batch_size 250 --early_stopping 10
+```
+
+
 ### Fast image loading
 Loading and decoding the png image files from the disk is the bottleneck even when using only a small GPU.
 The database LMDB is used to speed up image loading:
 
-* Go to the `src` directory and run `create_lmdb.py --data_dir path/to/IAM` with the IAM data directory specified
+* Go to the `src` directory and run `create_lmdb.py --data_dir path/to/iam` with the IAM data directory specified
 * A subfolder `lmdb` is created in the IAM data directory containing the LMDB files
 * When training the model, add the command line option `--fast`
 
 The dataset should be located on an SSD drive.
-Using the `--fast` option and a GTX 1050 Ti training single words takes around 3h with a batch size of 500.
-Training text lines takes a bit longer.
+Using the `--fast` option and a GTX 1050 Ti training on single words takes around 3h with a batch size of 500.
+Training on text lines takes a bit longer.
 
 ## Information about model
 
 The model is a stripped-down version of the HTR system I implemented for [my thesis](https://repositum.tuwien.ac.at/obvutwhs/download/pdf/2874742).
-What remains is what I think is the bare minimum to recognize text with an acceptable accuracy.
+What remains is the bare minimum to recognize text with an acceptable accuracy.
 It consists of 5 CNN layers, 2 RNN (LSTM) layers and the CTC loss and decoding layer.
-The illustration below gives an overview of the NN (green: operations, pink: data flowing through NN) and here follows a short description:
-
-* The input image is a gray-value image and has a size of 128x32
-  (in training mode the width is fixed, while in inference mode there is no restriction other than being a multiple of 4)
-* 5 CNN layers map the input image to a feature sequence of size 32x256
-* 2 LSTM layers with 256 units propagate information through the sequence and map the sequence to a matrix of size 32x80. Each matrix-element represents a score for one of the 80 characters at one of the 32 time-steps
-* The CTC layer either calculates the loss value given the matrix and the ground-truth text (when training), or it decodes the matrix to the final text with best path decoding or beam search decoding (when inferring)
-
-
+For more details see this [Medium article](https://towardsdatascience.com/2326a3487cd5).
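+
+To call the model from Python instead of the command line, the sketch below mirrors what `infer()` in `src/main.py` does. It is an illustration only: the character list path (`../model/charList.txt`) and the default decoder of `Model` are assumptions based on the pretrained model zip-file.
+
+```
+import cv2
+from dataloader_iam import Batch
+from model import Model
+from preprocessor import Preprocessor
+
+# Character list the model was trained with (assumed to ship with the model zip-file).
+char_list = list(open('../model/charList.txt').read())
+
+# Restore the latest snapshot from ../model; assumes the decoder
+# defaults to best path decoding when not specified.
+model = Model(char_list, must_restore=True)
+
+# Scale the image to the fixed NN input height of 32px, keep the width dynamic.
+preprocessor = Preprocessor((128, 32), dynamic_width=True, padding=16)
+img = preprocessor.process_img(cv2.imread('../data/word.png', cv2.IMREAD_GRAYSCALE))
+
+batch = Batch([img], None, 1)  # a single image, no ground-truth texts
+recognized, probability = model.infer_batch(batch, True)
+print(f'Recognized: "{recognized[0]}" with probability {probability[0]}')
+```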
 
 ## References
diff --git a/src/dataloader_iam.py b/src/dataloader_iam.py
index e4a5f438725a2879ba97167d389661b5163c0825..5aa8db7c7ca62fd6338bb7f120b97a6158c400d2 100644
--- a/src/dataloader_iam.py
+++ b/src/dataloader_iam.py
@@ -13,7 +13,10 @@ Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')
 
 
 class DataLoaderIAM:
-    "loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
+    """
+    Loads data which corresponds to IAM format,
+    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+    """
 
     def __init__(self,
                  data_dir: Path,
@@ -46,8 +49,10 @@ class DataLoaderIAM:
 
             # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
             file_name_split = line_split[0].split('-')
-            file_name = data_dir / 'img' / file_name_split[0] / f'{file_name_split[0]}-{file_name_split[1]}' / \
-                line_split[0] + '.png'
+            file_name_subdir1 = file_name_split[0]
+            file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
+            file_base_name = line_split[0] + '.png'
+            file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name
 
             if line_split[0] in bad_samples_reference:
                 print('Ignoring known broken image:', file_name)
@@ -84,7 +89,7 @@ class DataLoaderIAM:
         self.curr_set = 'train'
 
     def validation_set(self) -> None:
-        "switch to validation set"
+        """Switch to validation set."""
         self.data_augmentation = False
         self.curr_idx = 0
         self.samples = self.validation_samples
@@ -100,7 +105,7 @@ class DataLoaderIAM:
         return curr_batch, num_batches
 
     def has_next(self) -> bool:
-        "iterator"
+        """Is there a next element?"""
         if self.curr_set == 'train':
             return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
         else:
@@ -118,7 +123,7 @@ class DataLoaderIAM:
         return img
 
     def get_next(self) -> Batch:
-        "Iterator."
+ """Get next element.""" batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples))) imgs = [self._get_img(i) for i in batch_range] diff --git a/src/main.py b/src/main.py index ef0f84c0c0a4f024e40889854378c8feb0fea1c2..3b3ac436e0a584a0a3314990feaf382e990c7582 100644 --- a/src/main.py +++ b/src/main.py @@ -19,21 +19,27 @@ class FilePaths: def get_img_height() -> int: + """Fixed height for NN.""" return 32 def get_img_size(line_mode: bool = False) -> Tuple[int, int]: + """Height is fixed for NN, width is set according to training mode (single words or text lines).""" if line_mode: return 256, get_img_height() return 128, get_img_height() def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -> None: + """Writes training summary file for NN.""" with open(FilePaths.fn_summary, 'w') as f: json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f) -def train(model: Model, loader: DataLoaderIAM, line_mode: bool) -> None: +def train(model: Model, + loader: DataLoaderIAM, + line_mode: bool, + early_stopping: int = 25) -> None: """Trains NN.""" epoch = 0 # number of training epochs since start summary_char_error_rates = [] @@ -41,7 +47,7 @@ def train(model: Model, loader: DataLoaderIAM, line_mode: bool) -> None: preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode) best_char_error_rate = float('inf') # best valdiation character error rate no_improvement_since = 0 # number of epochs no improvement of character error rate occurred - early_stopping = 25 # stop training after this number of epochs without improvement + # stop training after this number of epochs without improvement while True: epoch += 1 print('Epoch:', epoch) @@ -115,8 +121,12 @@ def validate(model: Model, loader: DataLoaderIAM, line_mode: bool) -> Tuple[floa def infer(model: Model, fn_img: Path) -> None: """Recognizes text in image provided by file path.""" + img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE) + assert img is not None + preprocessor = Preprocessor(get_img_size(), dynamic_width=True, padding=16) - img = preprocessor.process_img(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)) + img = preprocessor.process_img(img) + batch = Batch([img], None, 1) recognized, probability = model.infer_batch(batch, True) print(f'Recognized: "{recognized[0]}"') @@ -134,6 +144,7 @@ def main(): parser.add_argument('--fast', help='Load samples from LMDB.', action='store_true') parser.add_argument('--line_mode', help='Train to read text lines instead of single words.', action='store_true') parser.add_argument('--img_file', help='Image used for inference.', type=Path, default='../data/word.png') + parser.add_argument('--early_stopping', help='Early stopping epochs.', type=int, default=25) parser.add_argument('--dump', help='Dump output of NN to CSV file(s).', action='store_true') args = parser.parse_args() @@ -162,7 +173,7 @@ def main(): # execute training or validation if args.mode == 'train': model = Model(char_list, decoder_type) - train(model, loader, args.line_mode) + train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping) elif args.mode == 'validate': model = Model(char_list, decoder_type, must_restore=True) validate(model, loader, args.line_mode) diff --git a/src/model.py b/src/model.py index 773ef33779a1770219c7a13580eb3c6293ed339e..65d5c17a39494519c569adba49cce8aec6144d1d 100644 --- a/src/model.py +++ b/src/model.py @@ -145,7 +145,7 @@ class Model: # the input to the decoder must have softmax 
         # the input to the decoder must have softmax already applied
         self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2)
 
-    def setup_tf(self) -> None:
+    def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]:
         """Initialize TF."""
         print('Python: ' + sys.version)
         print('Tensorflow: ' + tf.__version__)
@@ -170,7 +170,7 @@
 
         return sess, saver
 
-    def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], Tuple[int, int]]:
+    def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]:
         """Put ground truth texts into sparse tensor for ctc_loss."""
         indices = []
         values = []
diff --git a/src/preprocessor.py b/src/preprocessor.py
index af1a7c11f03aaea3f2ad05b254a8e96145149922..8c956ea7c5d86c4e1597746a06ce8929a9f35fd8 100644
--- a/src/preprocessor.py
+++ b/src/preprocessor.py
@@ -100,7 +100,8 @@ class Preprocessor:
         if self.data_augmentation:
             # photometric data augmentation
             if random.random() < 0.25:
-                rand_odd = lambda: random.randint(1, 3) * 2 + 1
+                def rand_odd():
+                    return random.randint(1, 3) * 2 + 1
                 img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
             if random.random() < 0.25:
                 img = cv2.dilate(img, np.ones((3, 3)))
@@ -174,7 +175,7 @@
         return Batch(res_imgs, res_gt_texts, batch.batch_size)
 
 
-if __name__ == '__main__':
+def main():
     import matplotlib.pyplot as plt
 
     img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
@@ -184,3 +185,7 @@
     plt.subplot(122)
     plt.imshow(cv2.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
     plt.show()
+
+
+if __name__ == '__main__':
+    main()
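+
+# Note: running `python preprocessor.py` from the `src` directory now simply
+# calls main() above, which visualizes a data-augmented variant of
+# ../data/test.png (the file must exist for the demo to work).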