Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
SimpleHTR
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Fabian Mersch
SimpleHTR
Commits
b9fc4e3a
Commit
b9fc4e3a
authored
6 years ago
by
Harald Scheidl
Browse files
Options
Downloads
Patches
Plain Diff
using character error rate instead of word accuracy
parent
2bacab23
No related branches found
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/main.py
+56
-78
56 additions, 78 deletions
src/main.py
with
56 additions
and
78 deletions
src/main.py
+
56
−
78
View file @
b9fc4e3a
import
sys
import
sys
import
argparse
import
argparse
import
cv2
import
cv2
import
editdistance
from
DataLoader
import
DataLoader
,
Batch
from
DataLoader
import
DataLoader
,
Batch
from
Model
import
Model
from
Model
import
Model
from
SamplePreprocessor
import
preprocess
from
SamplePreprocessor
import
preprocess
# filenames and paths to data
class
FilePaths
:
"
filenames and paths to data
"
fnCharList
=
'
../model/charList.txt
'
fnCharList
=
'
../model/charList.txt
'
fnAccuracy
=
'
../model/accuracy.txt
'
fnAccuracy
=
'
../model/accuracy.txt
'
fnTrain
=
'
../data/
'
fnTrain
=
'
../data/
'
fnInfer
=
'
../data/test.png
'
fnInfer
=
'
../data/test.png
'
useBeamSearch
=
False
def
train
(
filePath
):
def
train
(
model
,
loader
):
"
train NN
"
"
train NN
"
# load training data
loader
=
DataLoader
(
filePath
,
Model
.
batchSize
,
Model
.
imgSize
,
Model
.
maxTextLen
)
# create TF model
model
=
Model
(
loader
.
charList
,
useBeamSearch
)
# save characters of model for inference mode
open
(
fnCharList
,
'
w
'
).
write
(
str
().
join
(
loader
.
charList
))
# train forever
epoch
=
0
# number of training epochs since start
epoch
=
0
# number of training epochs since start
best
Accuracy
=
0.0
# best valdiation accuracy
best
CharErrorRate
=
float
(
'
inf
'
)
# best valdiation character error rate
noImprovementSince
=
0
# number of epochs no improvement of
accuracy
occured
noImprovementSince
=
0
# number of epochs no improvement of
character error rate
occured
earlyStopping
=
3
# stop training after this number of epochs without improvement
earlyStopping
=
3
# stop training after this number of epochs without improvement
while
True
:
while
True
:
epoch
+=
1
epoch
+=
1
...
@@ -44,36 +35,17 @@ def train(filePath):
...
@@ -44,36 +35,17 @@ def train(filePath):
print
(
'
Batch:
'
,
iterInfo
[
0
],
'
/
'
,
iterInfo
[
1
],
'
Loss:
'
,
loss
)
print
(
'
Batch:
'
,
iterInfo
[
0
],
'
/
'
,
iterInfo
[
1
],
'
Loss:
'
,
loss
)
# validate
# validate
print
(
'
Validate NN
'
)
charErrorRate
=
validate
(
model
,
loader
)
loader
.
validationSet
()
numOK
=
0
numTotal
=
0
while
loader
.
hasNext
():
iterInfo
=
loader
.
getIteratorInfo
()
print
(
'
Batch:
'
,
iterInfo
[
0
],
'
/
'
,
iterInfo
[
1
])
batch
=
loader
.
getNext
()
recognized
=
model
.
inferBatch
(
batch
)
print
(
'
Ground truth -> Recognized
'
)
for
i
in
range
(
len
(
recognized
)):
isOK
=
batch
.
gtTexts
[
i
]
==
recognized
[
i
]
print
(
'
[OK]
'
if
isOK
else
'
[ERR]
'
,
'"'
+
batch
.
gtTexts
[
i
]
+
'"'
,
'
->
'
,
'"'
+
recognized
[
i
]
+
'"'
)
numOK
+=
1
if
isOK
else
0
numTotal
+=
1
# print validation result
accuracy
=
numOK
/
numTotal
print
(
'
Correctly recognized words:
'
,
accuracy
*
100.0
,
'
%
'
)
# if best validation accuracy so far, save model parameters
# if best validation accuracy so far, save model parameters
if
accuracy
>
bestAccuracy
:
if
charErrorRate
<
bestCharErrorRate
:
print
(
'
Accuracy
improved, save model
'
)
print
(
'
Character error rate
improved, save model
'
)
best
Accuracy
=
accuracy
best
CharErrorRate
=
charErrorRate
noImprovementSince
=
0
noImprovementSince
=
0
model
.
save
()
model
.
save
()
open
(
fnAccuracy
,
'
w
'
).
write
(
'
Validation
accuracy of saved model:
'
+
str
(
accuracy
))
open
(
FilePaths
.
fnAccuracy
,
'
w
'
).
write
(
'
Validation
character error rate of saved model: %f%%.
'
%
(
charErrorRate
*
100.0
))
else
:
else
:
print
(
'
Accuracy
not improved
'
)
print
(
'
Character error rate
not improved
'
)
noImprovementSince
+=
1
noImprovementSince
+=
1
# stop training if no more improvement in the last x epochs
# stop training if no more improvement in the last x epochs
...
@@ -82,20 +54,11 @@ def train(filePath):
...
@@ -82,20 +54,11 @@ def train(filePath):
break
break
def
validate
(
filePath
):
def
validate
(
model
,
loader
):
"
validate NN
"
"
validate NN
"
# load training data
loader
=
DataLoader
(
filePath
,
Model
.
batchSize
,
Model
.
imgSize
,
Model
.
maxTextLen
)
# create TF model
model
=
Model
(
loader
.
charList
,
useBeamSearch
)
# save characters of model for inference mode
open
(
fnCharList
,
'
w
'
).
write
(
str
().
join
(
loader
.
charList
))
print
(
'
Validate NN
'
)
print
(
'
Validate NN
'
)
loader
.
validationSet
()
loader
.
validationSet
()
num
OK
=
0
num
Err
=
0
numTotal
=
0
numTotal
=
0
while
loader
.
hasNext
():
while
loader
.
hasNext
():
iterInfo
=
loader
.
getIteratorInfo
()
iterInfo
=
loader
.
getIteratorInfo
()
...
@@ -105,26 +68,27 @@ def validate(filePath):
...
@@ -105,26 +68,27 @@ def validate(filePath):
print
(
'
Ground truth -> Recognized
'
)
print
(
'
Ground truth -> Recognized
'
)
for
i
in
range
(
len
(
recognized
)):
for
i
in
range
(
len
(
recognized
)):
is
OK
=
batch
.
gtTexts
[
i
]
==
recognized
[
i
]
d
is
t
=
editdistance
.
eval
(
recognized
[
i
],
batch
.
gtTexts
[
i
])
print
(
'
[OK]
'
if
isOK
else
'
[ERR]
'
,
'"'
+
batch
.
gtTexts
[
i
]
+
'"'
,
'
->
'
,
'"'
+
recognized
[
i
]
+
'"'
)
numErr
+=
dist
num
OK
+=
1
if
isOK
else
0
num
Total
+=
len
(
batch
.
gtTexts
[
i
])
numTotal
+=
1
print
(
'
[OK]
'
if
dist
==
0
else
'
[ERR:%d]
'
%
dist
,
'"'
+
batch
.
gtTexts
[
i
]
+
'"'
,
'
->
'
,
'"'
+
recognized
[
i
]
+
'"'
)
# print validation result
# print validation result
accuracy
=
numOK
/
numTotal
charErrorRate
=
numErr
/
numTotal
print
(
'
Correctly recognized words:
'
,
accuracy
*
100.0
,
'
%
'
)
print
(
'
Character error rate: %f%%
'
%
(
charErrorRate
*
100.0
))
return
charErrorRate
def
infer
(
filePath
):
def
infer
(
model
,
fnImg
):
"
recognize text in image provided by file path
"
"
recognize text in image provided by file path
"
model
=
Model
(
open
(
fnCharList
).
read
(),
useBeamSearch
,
mustRestore
=
True
)
img
=
preprocess
(
cv2
.
imread
(
fnImg
,
cv2
.
IMREAD_GRAYSCALE
),
Model
.
imgSize
)
img
=
preprocess
(
cv2
.
imread
(
fnInfer
,
cv2
.
IMREAD_GRAYSCALE
),
Model
.
imgSize
)
batch
=
Batch
(
None
,
[
img
]
*
Model
.
batchSize
)
# fill all batch elements with same input image
batch
=
Batch
(
None
,
[
img
]
*
Model
.
batchSize
)
recognized
=
model
.
inferBatch
(
batch
)
# recognize text
recognized
=
model
.
inferBatch
(
batch
)
print
(
'
Recognized:
'
,
'"'
+
recognized
[
0
]
+
'"'
)
# all batch elements hold same result
print
(
'
Recognized:
'
,
'"'
+
recognized
[
0
]
+
'"'
)
if
__name__
==
'
__main__
'
:
def
main
():
"
main function
"
# optional command line args
# optional command line args
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"
--train
"
,
help
=
"
train the NN
"
,
action
=
"
store_true
"
)
parser
.
add_argument
(
"
--train
"
,
help
=
"
train the NN
"
,
action
=
"
store_true
"
)
...
@@ -132,15 +96,29 @@ if __name__ == '__main__':
...
@@ -132,15 +96,29 @@ if __name__ == '__main__':
parser
.
add_argument
(
"
--beamsearch
"
,
help
=
"
use beam search instead of best path decoding
"
,
action
=
"
store_true
"
)
parser
.
add_argument
(
"
--beamsearch
"
,
help
=
"
use beam search instead of best path decoding
"
,
action
=
"
store_true
"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# use beam search (better accuracy, but slower) instead of best path decoding
# train or validate on IAM dataset
if
args
.
beamsearch
:
if
args
.
train
or
args
.
validate
:
useBeamSearch
=
True
# load training data, create TF model
loader
=
DataLoader
(
FilePaths
.
fnTrain
,
Model
.
batchSize
,
Model
.
imgSize
,
Model
.
maxTextLen
)
# save characters of model for inference mode
open
(
FilePaths
.
fnCharList
,
'
w
'
).
write
(
str
().
join
(
loader
.
charList
))
# train or validat
e NN, or infer text on the text image
#
execute
train
ing
or validat
ion
if
args
.
train
:
if
args
.
train
:
train
(
fnTrain
)
model
=
Model
(
loader
.
charList
,
args
.
beamsearch
)
train
(
model
,
loader
)
elif
args
.
validate
:
elif
args
.
validate
:
validate
(
fnTrain
)
model
=
Model
(
loader
.
charList
,
args
.
beamsearch
,
mustRestore
=
True
)
validate
(
model
,
loader
)
# infer text on test image
else
:
else
:
infer
(
fnInfer
)
print
(
open
(
FilePaths
.
fnAccuracy
).
read
())
model
=
Model
(
open
(
FilePaths
.
fnCharList
).
read
(),
args
.
beamsearch
,
mustRestore
=
True
)
infer
(
model
,
FilePaths
.
fnInfer
)
if
__name__
==
'
__main__
'
:
main
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment