Skip to content
Snippets Groups Projects
Commit ab6220d6 authored by Harald Scheidl's avatar Harald Scheidl
Browse files

readme, docstrings, and other small changes

parent 2b0b5484
No related branches found
No related tags found
No related merge requests found
......@@ -14,10 +14,10 @@ The model takes **images of single words or text lines (multiple words) as input
## Run demo
* Download one of the trained models
* [Model trained on word images](https://www.dropbox.com/s/lod3gabgtuj0zzn/model.zip?dl=1):
only handle single words per image, but gives better results for IAM word dataset
* [Model trained on text line images](TODO):
* Download one of the pretrained models
* [Model trained on word images](https://www.dropbox.com/s/mya8hw6jyzqm0a3/word-model.zip?dl=1):
only handles single words per image, but gives better results on the IAM word dataset
* [Model trained on text line images](https://www.dropbox.com/s/7xwkcilho10rthn/line-model.zip?dl=1):
can handle multiple words in one image
* Put the contents of the downloaded zip-file into the `model` directory of the repository
* Go to the `src` directory
......@@ -30,28 +30,27 @@ The input images, and the expected outputs are shown below when the text line mo
![test](./data/word.png)
```
> python main.py
Init with stored values from ../model/snapshot-15
Init with stored values from ../model/snapshot-13
Recognized: "word"
Probability: 0.9741360545158386
Probability: 0.9806370139122009
```
![test](./data/line.png)
```
> python main.py --img_file ../data/line.png
Init with stored values from ../model/snapshot-15
Init with stored values from ../model/snapshot-13
Recognized: "or work on line level"
Probability: 0.8010453581809998
Probability: 0.6674373149871826
```
## Command line arguments
* `--mode`: select between "train", "validate" and "infer". Defaults to "infer".
* `--decoder`: select from CTC decoders "bestpath", "beamsearch" and "wordbeamsearch". Defaults to "bestpath". For option "wordbeamsearch" see details below.
* `--batch_size`: batch size.
* `--data_dir`: directory containing IAM dataset (with subdirectories `img` and `gt`).
* `--fast`: use LMDB to load images (faster than loading image files from disk).
* `--line_mode`': train reading text lines instead of single words
* `--fast`: use LMDB to load images faster.
* `--line_mode`: train reading text lines instead of single words.
* `--img_file`: image that is used for inference.
* `--dump`: dumps the output of the NN to CSV file(s) saved in the `dump` folder. Can be used as input for the [CTCDecoder](https://github.com/githubharald/CTCDecoder).
......@@ -75,8 +74,9 @@ Further, the manually created list of word-characters can be found in the file `
Beam width is set to 50 to conform with the beam width of vanilla beam search decoding.
## Train model with IAM dataset
## Train model on IAM dataset
### Prepare dataset
Follow these instructions to get the IAM dataset:
* Register for free at this [website](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database)
......@@ -86,7 +86,7 @@ Follow these instructions to get the IAM dataset:
* Put `words.txt` into the `gt` directory
* Put the content (directories `a01`, `a02`, ...) of `words.tgz` into the `img` directory
### Start the training
### Run training
* Delete files from `model` directory if you want to train from scratch
* Go to the `src` directory and execute `python main.py --mode train --data_dir path/to/IAM`
......@@ -95,32 +95,35 @@ Follow these instructions to get the IAM dataset:
the model is trained on text line images created by combining multiple word images into one
* Training stops after a fixed number of epochs without improvement
The pretrained word model was trained with this command on a GTX 1050 Ti:
```
python main.py --mode train --fast --data_dir path/to/iam --batch_size 500 --early_stopping 15
```
And the line model with:
```
python main.py --mode train --fast --data_dir path/to/iam --batch_size 250 --early_stopping 10
```
### Fast image loading
Loading and decoding the png image files from the disk is the bottleneck even when using only a small GPU.
The database LMDB is used to speed up image loading:
* Go to the `src` directory and run `create_lmdb.py --data_dir path/to/IAM` with the IAM data directory specified
* Go to the `src` directory and run `create_lmdb.py --data_dir path/to/iam` with the IAM data directory specified
* A subfolder `lmdb` is created in the IAM data directory containing the LMDB files
* When training the model, add the command line option `--fast`
The dataset should be located on an SSD drive.
Using the `--fast` option and a GTX 1050 Ti training single words takes around 3h with a batch size of 500.
Training text lines takes a bit longer.
Using the `--fast` option and a GTX 1050 Ti training on single words takes around 3h with a batch size of 500.
Training on text lines takes a bit longer.
## Information about model
The model is a stripped-down version of the HTR system I implemented for [my thesis]((https://repositum.tuwien.ac.at/obvutwhs/download/pdf/2874742)).
What remains is what I think is the bare minimum to recognize text with an acceptable accuracy.
What remains is the bare minimum to recognize text with an acceptable accuracy.
It consists of 5 CNN layers, 2 RNN (LSTM) layers and the CTC loss and decoding layer.
The illustration below gives an overview of the NN (green: operations, pink: data flowing through NN) and here follows a short description:
* The input image is a gray-value image and has a size of 128x32
(in training mode the width is fixed, while in inference mode there is no restriction other than being a multiple of 4)
* 5 CNN layers map the input image to a feature sequence of size 32x256
* 2 LSTM layers with 256 units propagate information through the sequence and map the sequence to a matrix of size 32x80. Each matrix-element represents a score for one of the 80 characters at one of the 32 time-steps
* The CTC layer either calculates the loss value given the matrix and the ground-truth text (when training), or it decodes the matrix to the final text with best path decoding or beam search decoding (when inferring)
![nn_overview](./doc/nn_overview.png)
For more details see this [Medium article](https://towardsdatascience.com/2326a3487cd5).
## References
......
......@@ -13,7 +13,10 @@ Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')
class DataLoaderIAM:
"loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database"
"""
Loads data which corresponds to IAM format,
see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
"""
def __init__(self,
data_dir: Path,
......@@ -46,8 +49,10 @@ class DataLoaderIAM:
# filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
file_name_split = line_split[0].split('-')
file_name = data_dir / 'img' / file_name_split[0] / f'{file_name_split[0]}-{file_name_split[1]}' / \
line_split[0] + '.png'
file_name_subdir1 = file_name_split[0]
file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
file_base_name = line_split[0] + '.png'
file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name
if line_split[0] in bad_samples_reference:
print('Ignoring known broken image:', file_name)
......@@ -84,7 +89,7 @@ class DataLoaderIAM:
self.curr_set = 'train'
def validation_set(self) -> None:
"switch to validation set"
"""Switch to validation set."""
self.data_augmentation = False
self.curr_idx = 0
self.samples = self.validation_samples
......@@ -100,7 +105,7 @@ class DataLoaderIAM:
return curr_batch, num_batches
def has_next(self) -> bool:
"iterator"
"""Is there a next element?"""
if self.curr_set == 'train':
return self.curr_idx + self.batch_size <= len(self.samples) # train set: only full-sized batches
else:
......@@ -118,7 +123,7 @@ class DataLoaderIAM:
return img
def get_next(self) -> Batch:
"Iterator."
"""Get next element."""
batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
imgs = [self._get_img(i) for i in batch_range]
......
......@@ -19,21 +19,27 @@ class FilePaths:
def get_img_height() -> int:
"""Fixed height for NN."""
return 32
def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
"""Height is fixed for NN, width is set according to training mode (single words or text lines)."""
if line_mode:
return 256, get_img_height()
return 128, get_img_height()
def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -> None:
"""Writes training summary file for NN."""
with open(FilePaths.fn_summary, 'w') as f:
json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
def train(model: Model, loader: DataLoaderIAM, line_mode: bool) -> None:
def train(model: Model,
loader: DataLoaderIAM,
line_mode: bool,
early_stopping: int = 25) -> None:
"""Trains NN."""
epoch = 0 # number of training epochs since start
summary_char_error_rates = []
......@@ -41,7 +47,7 @@ def train(model: Model, loader: DataLoaderIAM, line_mode: bool) -> None:
preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode)
best_char_error_rate = float('inf') # best valdiation character error rate
no_improvement_since = 0 # number of epochs no improvement of character error rate occurred
early_stopping = 25 # stop training after this number of epochs without improvement
# stop training after this number of epochs without improvement
while True:
epoch += 1
print('Epoch:', epoch)
......@@ -115,8 +121,12 @@ def validate(model: Model, loader: DataLoaderIAM, line_mode: bool) -> Tuple[floa
def infer(model: Model, fn_img: Path) -> None:
"""Recognizes text in image provided by file path."""
img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
assert img is not None
preprocessor = Preprocessor(get_img_size(), dynamic_width=True, padding=16)
img = preprocessor.process_img(cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE))
img = preprocessor.process_img(img)
batch = Batch([img], None, 1)
recognized, probability = model.infer_batch(batch, True)
print(f'Recognized: "{recognized[0]}"')
......@@ -134,6 +144,7 @@ def main():
parser.add_argument('--fast', help='Load samples from LMDB.', action='store_true')
parser.add_argument('--line_mode', help='Train to read text lines instead of single words.', action='store_true')
parser.add_argument('--img_file', help='Image used for inference.', type=Path, default='../data/word.png')
parser.add_argument('--early_stopping', help='Early stopping epochs.', type=int, default=25)
parser.add_argument('--dump', help='Dump output of NN to CSV file(s).', action='store_true')
args = parser.parse_args()
......@@ -162,7 +173,7 @@ def main():
# execute training or validation
if args.mode == 'train':
model = Model(char_list, decoder_type)
train(model, loader, args.line_mode)
train(model, loader, line_mode=args.line_mode, early_stopping=args.early_stopping)
elif args.mode == 'validate':
model = Model(char_list, decoder_type, must_restore=True)
validate(model, loader, args.line_mode)
......
......@@ -145,7 +145,7 @@ class Model:
# the input to the decoder must have softmax already applied
self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2)
def setup_tf(self) -> None:
def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]:
"""Initialize TF."""
print('Python: ' + sys.version)
print('Tensorflow: ' + tf.__version__)
......@@ -170,7 +170,7 @@ class Model:
return sess, saver
def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], Tuple[int, int]]:
def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]:
"""Put ground truth texts into sparse tensor for ctc_loss."""
indices = []
values = []
......
......@@ -100,7 +100,8 @@ class Preprocessor:
if self.data_augmentation:
# photometric data augmentation
if random.random() < 0.25:
rand_odd = lambda: random.randint(1, 3) * 2 + 1
def rand_odd():
return random.randint(1, 3) * 2 + 1
img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
if random.random() < 0.25:
img = cv2.dilate(img, np.ones((3, 3)))
......@@ -174,7 +175,7 @@ class Preprocessor:
return Batch(res_imgs, res_gt_texts, batch.batch_size)
if __name__ == '__main__':
def main():
import matplotlib.pyplot as plt
img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
......@@ -184,3 +185,7 @@ if __name__ == '__main__':
plt.subplot(122)
plt.imshow(cv2.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
plt.show()
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment