Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
SimpleHTR
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Fabian Mersch
SimpleHTR
Commits
6c76bca2
Commit
6c76bca2
authored
4 years ago
by
Harald Scheidl
Browse files
Options
Downloads
Patches
Plain Diff
started multi word training
parent
7edd1f86
No related branches found
No related tags found
No related merge requests found
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
src/dataloader_iam.py
+56
-15
56 additions, 15 deletions
src/dataloader_iam.py
src/main.py
+2
-2
2 additions, 2 deletions
src/main.py
src/model.py
+1
-1
1 addition, 1 deletion
src/model.py
src/preprocessor.py
+7
-4
7 additions, 4 deletions
src/preprocessor.py
with
66 additions
and
22 deletions
src/dataloader_iam.py
+
56
−
15
View file @
6c76bca2
...
...
@@ -7,7 +7,7 @@ import lmdb
import
numpy
as
np
from
path
import
Path
from
preprocess
import
preprocess
from
preprocessor
import
preprocess
Sample
=
namedtuple
(
'
Sample
'
,
'
gt_text, file_path
'
)
Batch
=
namedtuple
(
'
Batch
'
,
'
gt_texts, imgs
'
)
...
...
@@ -16,7 +16,7 @@ Batch = namedtuple('Batch', 'gt_texts, imgs')
class
DataLoaderIAM
:
"
loads data which corresponds to IAM format, see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
"
def
__init__
(
self
,
data_dir
,
batch_size
,
img_size
,
max_text_len
,
fast
=
True
):
def
__init__
(
self
,
data_dir
,
batch_size
,
img_size
,
max_text_len
,
fast
=
True
,
multi_word_mode
=
False
):
"""
Loader for dataset.
"""
assert
data_dir
.
exists
()
...
...
@@ -25,6 +25,9 @@ class DataLoaderIAM:
if
fast
:
self
.
env
=
lmdb
.
open
(
str
(
data_dir
/
'
lmdb
'
),
readonly
=
True
)
self
.
max_text_len
=
max_text_len
self
.
multi_word_mode
=
multi_word_mode
self
.
data_augmentation
=
False
self
.
curr_idx
=
0
self
.
batch_size
=
batch_size
...
...
@@ -52,7 +55,7 @@ class DataLoaderIAM:
continue
# GT text are columns starting at 9
gt_text
=
self
.
truncate_label
(
'
'
.
join
(
line_split
[
8
:])
,
max_text_len
)
gt_text
=
'
'
.
join
(
line_split
[
8
:])
chars
=
chars
.
union
(
set
(
list
(
gt_text
)))
# put sample into list
...
...
@@ -71,10 +74,13 @@ class DataLoaderIAM:
self
.
train_set
()
# list of all chars in dataset
if
multi_word_mode
:
chars
.
add
(
'
'
)
self
.
char_list
=
sorted
(
list
(
chars
))
@staticmethod
def
truncate_label
(
text
,
max_text_len
):
def
_truncate_label
(
text
,
max_text_len
):
"""
Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
labels. Repeat letters cost double because of the blank symbol needing to be inserted.
...
...
@@ -121,13 +127,7 @@ class DataLoaderIAM:
else
:
return
self
.
curr_idx
<
len
(
self
.
samples
)
# val set: allow last batch to be smaller
def
get_next
(
self
):
"
iterator
"
batch_range
=
range
(
self
.
curr_idx
,
min
(
self
.
curr_idx
+
self
.
batch_size
,
len
(
self
.
samples
)))
gt_texts
=
[
self
.
samples
[
i
].
gt_text
for
i
in
batch_range
]
imgs
=
[]
for
i
in
batch_range
:
def
_get_img
(
self
,
i
):
if
self
.
fast
:
with
self
.
env
.
begin
()
as
txn
:
basename
=
Path
(
self
.
samples
[
i
].
file_path
).
basename
()
...
...
@@ -136,7 +136,48 @@ class DataLoaderIAM:
else
:
img
=
cv2
.
imread
(
self
.
samples
[
i
].
file_path
,
cv2
.
IMREAD_GRAYSCALE
)
imgs
.
append
(
preprocess
(
img
,
self
.
img_size
,
data_augmentation
=
self
.
data_augmentation
))
return
img
@staticmethod
def
_simulate_multi_words
(
imgs
,
gt_texts
):
batch_size
=
len
(
imgs
)
res_imgs
=
[]
res_gt_texts
=
[]
word_sep_space
=
30
for
i
in
range
(
batch_size
):
j
=
(
i
+
1
)
%
batch_size
img_left
=
imgs
[
i
]
img_right
=
imgs
[
j
]
h
=
max
(
img_left
.
shape
[
0
],
img_right
.
shape
[
0
])
w
=
img_left
.
shape
[
1
]
+
img_right
.
shape
[
1
]
+
word_sep_space
target
=
np
.
ones
([
h
,
w
],
np
.
uint8
)
*
255
target
[
-
img_left
.
shape
[
0
]:,
:
img_left
.
shape
[
1
]]
=
img_left
target
[
-
img_right
.
shape
[
0
]:,
-
img_right
.
shape
[
1
]:]
=
img_right
res_imgs
.
append
(
target
)
res_gt_texts
.
append
(
gt_texts
[
i
]
+
'
'
+
gt_texts
[
j
])
return
res_imgs
,
res_gt_texts
def
get_next
(
self
):
"
Iterator.
"
batch_range
=
range
(
self
.
curr_idx
,
min
(
self
.
curr_idx
+
self
.
batch_size
,
len
(
self
.
samples
)))
imgs
=
[
self
.
_get_img
(
i
)
for
i
in
batch_range
]
gt_texts
=
[
self
.
samples
[
i
].
gt_text
for
i
in
batch_range
]
if
self
.
multi_word_mode
:
imgs
,
gt_texts
=
self
.
_simulate_multi_words
(
imgs
,
gt_texts
)
# apply data augmentation to images
imgs
=
[
preprocess
(
img
,
self
.
img_size
,
data_augmentation
=
self
.
data_augmentation
)
for
img
in
imgs
]
gt_texts
=
[
self
.
_truncate_label
(
gt_text
,
self
.
max_text_len
)
for
gt_text
in
gt_texts
]
self
.
curr_idx
+=
self
.
batch_size
return
Batch
(
gt_texts
,
imgs
)
This diff is collapsed.
Click to expand it.
src/main.py
+
2
−
2
View file @
6c76bca2
...
...
@@ -7,7 +7,7 @@ from path import Path
from
dataloader_iam
import
DataLoaderIAM
,
Batch
from
model
import
Model
,
DecoderType
from
preprocess
import
preprocess
from
preprocessor
import
preprocess
class
FilePaths
:
...
...
@@ -130,7 +130,7 @@ def main():
# train or validate on IAM dataset
if
args
.
train
or
args
.
validate
:
# load training data, create TF model
loader
=
DataLoaderIAM
(
args
.
data_dir
,
args
.
batch_size
,
Model
.
img_size
,
Model
.
max_text_len
,
args
.
fast
)
loader
=
DataLoaderIAM
(
args
.
data_dir
,
args
.
batch_size
,
Model
.
img_size
,
Model
.
max_text_len
,
args
.
fast
,
True
)
# save characters of model for inference mode
open
(
FilePaths
.
fn_char_list
,
'
w
'
).
write
(
str
().
join
(
loader
.
char_list
))
...
...
This diff is collapsed.
Click to expand it.
src/model.py
+
1
−
1
View file @
6c76bca2
...
...
@@ -259,7 +259,7 @@ class Model:
if
self
.
dump
or
calc_probability
:
eval_list
.
append
(
self
.
ctc_in_3d_tbc
)
# sequence length
# sequence length depends on input image size (model downsizes width by 4)
max_text_len
=
batch
.
imgs
[
0
].
shape
[
0
]
//
4
# dict containing all tensor fed into the model
...
...
This diff is collapsed.
Click to expand it.
src/preprocess.py
→
src/preprocessor.py
+
7
−
4
View file @
6c76bca2
...
...
@@ -3,6 +3,9 @@ import random
import
cv2
import
numpy
as
np
# TODO: change to class
# TODO: do multi-word simulation in here!
def
preprocess
(
img
,
img_size
,
dynamic_width
=
False
,
data_augmentation
=
False
):
"
put img into target img of size imgSize, transpose for TF and normalize gray-values
"
...
...
@@ -33,14 +36,14 @@ def preprocess(img, img_size, dynamic_width=False, data_augmentation=False):
wt
,
ht
=
img_size
h
,
w
=
img
.
shape
f
=
min
(
wt
/
w
,
ht
/
h
)
fx
=
f
*
np
.
random
.
uniform
(
0.75
,
1.25
)
fy
=
f
*
np
.
random
.
uniform
(
0.75
,
1.25
)
fx
=
f
*
np
.
random
.
uniform
(
0.75
,
1.1
)
fy
=
f
*
np
.
random
.
uniform
(
0.75
,
1.1
)
# random position around center
txc
=
(
wt
-
w
*
fx
)
/
2
tyc
=
(
ht
-
h
*
fy
)
/
2
freedom_x
=
max
((
wt
-
fx
*
w
)
/
2
,
0
)
+
wt
/
10
freedom_y
=
max
((
ht
-
fy
*
h
)
/
2
,
0
)
+
ht
/
10
freedom_x
=
max
((
wt
-
fx
*
w
)
/
2
,
0
)
freedom_y
=
max
((
ht
-
fy
*
h
)
/
2
,
0
)
tx
=
txc
+
np
.
random
.
uniform
(
-
freedom_x
,
freedom_x
)
ty
=
tyc
+
np
.
random
.
uniform
(
-
freedom_y
,
freedom_y
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment