Skip to content
Snippets Groups Projects
Commit b9fc4e3a authored by Harald Scheidl's avatar Harald Scheidl
Browse files

using character error rate instead of word accuracy

parent 2bacab23
No related branches found
No related tags found
No related merge requests found
import sys import sys
import argparse import argparse
import cv2 import cv2
import editdistance
from DataLoader import DataLoader, Batch from DataLoader import DataLoader, Batch
from Model import Model from Model import Model
from SamplePreprocessor import preprocess from SamplePreprocessor import preprocess
# filenames and paths to data class FilePaths:
"filenames and paths to data"
fnCharList = '../model/charList.txt' fnCharList = '../model/charList.txt'
fnAccuracy = '../model/accuracy.txt' fnAccuracy = '../model/accuracy.txt'
fnTrain = '../data/' fnTrain = '../data/'
fnInfer = '../data/test.png' fnInfer = '../data/test.png'
useBeamSearch = False
def train(filePath): def train(model, loader):
"train NN" "train NN"
# load training data
loader = DataLoader(filePath, Model.batchSize, Model.imgSize, Model.maxTextLen)
# create TF model
model = Model(loader.charList, useBeamSearch)
# save characters of model for inference mode
open(fnCharList, 'w').write(str().join(loader.charList))
# train forever
epoch = 0 # number of training epochs since start epoch = 0 # number of training epochs since start
bestAccuracy = 0.0 # best validation accuracy bestCharErrorRate = float('inf') # best validation character error rate
noImprovementSince = 0 # number of epochs no improvement of accuracy occurred noImprovementSince = 0 # number of epochs no improvement of character error rate occurred
earlyStopping = 3 # stop training after this number of epochs without improvement earlyStopping = 3 # stop training after this number of epochs without improvement
while True: while True:
epoch += 1 epoch += 1
...@@ -44,36 +35,17 @@ def train(filePath): ...@@ -44,36 +35,17 @@ def train(filePath):
print('Batch:', iterInfo[0],'/', iterInfo[1], 'Loss:', loss) print('Batch:', iterInfo[0],'/', iterInfo[1], 'Loss:', loss)
# validate # validate
print('Validate NN') charErrorRate = validate(model, loader)
loader.validationSet()
numOK = 0
numTotal = 0
while loader.hasNext():
iterInfo = loader.getIteratorInfo()
print('Batch:', iterInfo[0],'/', iterInfo[1])
batch = loader.getNext()
recognized = model.inferBatch(batch)
print('Ground truth -> Recognized')
for i in range(len(recognized)):
isOK = batch.gtTexts[i] == recognized[i]
print('[OK]' if isOK else '[ERR]','"' + batch.gtTexts[i] + '"', '->', '"' + recognized[i] + '"')
numOK += 1 if isOK else 0
numTotal +=1
# print validation result
accuracy = numOK / numTotal
print('Correctly recognized words:', accuracy * 100.0, '%')
# if best validation accuracy so far, save model parameters # if best validation accuracy so far, save model parameters
if accuracy > bestAccuracy: if charErrorRate < bestCharErrorRate:
print('Accuracy improved, save model') print('Character error rate improved, save model')
bestAccuracy = accuracy bestCharErrorRate = charErrorRate
noImprovementSince = 0 noImprovementSince = 0
model.save() model.save()
open(fnAccuracy, 'w').write('Validation accuracy of saved model: '+str(accuracy)) open(FilePaths.fnAccuracy, 'w').write('Validation character error rate of saved model: %f%%.' % (charErrorRate*100.0))
else: else:
print('Accuracy not improved') print('Character error rate not improved')
noImprovementSince += 1 noImprovementSince += 1
# stop training if no more improvement in the last x epochs # stop training if no more improvement in the last x epochs
...@@ -82,20 +54,11 @@ def train(filePath): ...@@ -82,20 +54,11 @@ def train(filePath):
break break
def validate(filePath): def validate(model, loader):
"validate NN" "validate NN"
# load training data
loader = DataLoader(filePath, Model.batchSize, Model.imgSize, Model.maxTextLen)
# create TF model
model = Model(loader.charList, useBeamSearch)
# save characters of model for inference mode
open(fnCharList, 'w').write(str().join(loader.charList))
print('Validate NN') print('Validate NN')
loader.validationSet() loader.validationSet()
numOK = 0 numErr = 0
numTotal = 0 numTotal = 0
while loader.hasNext(): while loader.hasNext():
iterInfo = loader.getIteratorInfo() iterInfo = loader.getIteratorInfo()
...@@ -105,26 +68,27 @@ def validate(filePath): ...@@ -105,26 +68,27 @@ def validate(filePath):
print('Ground truth -> Recognized') print('Ground truth -> Recognized')
for i in range(len(recognized)): for i in range(len(recognized)):
isOK = batch.gtTexts[i] == recognized[i] dist = editdistance.eval(recognized[i], batch.gtTexts[i])
print('[OK]' if isOK else '[ERR]','"' + batch.gtTexts[i] + '"', '->', '"' + recognized[i] + '"') numErr += dist
numOK += 1 if isOK else 0 numTotal += len(batch.gtTexts[i])
numTotal +=1 print('[OK]' if dist==0 else '[ERR:%d]' % dist,'"' + batch.gtTexts[i] + '"', '->', '"' + recognized[i] + '"')
# print validation result # print validation result
accuracy = numOK / numTotal charErrorRate = numErr / numTotal
print('Correctly recognized words:', accuracy * 100.0, '%') print('Character error rate: %f%%' % (charErrorRate*100.0))
return charErrorRate
def infer(filePath): def infer(model, fnImg):
"recognize text in image provided by file path" "recognize text in image provided by file path"
model = Model(open(fnCharList).read(), useBeamSearch, mustRestore=True) img = preprocess(cv2.imread(fnImg, cv2.IMREAD_GRAYSCALE), Model.imgSize)
img = preprocess(cv2.imread(fnInfer, cv2.IMREAD_GRAYSCALE), Model.imgSize) batch = Batch(None, [img] * Model.batchSize) # fill all batch elements with same input image
batch = Batch(None, [img] * Model.batchSize) recognized = model.inferBatch(batch) # recognize text
recognized = model.inferBatch(batch) print('Recognized:', '"' + recognized[0] + '"') # all batch elements hold same result
print('Recognized:', '"' + recognized[0] + '"')
if __name__ == '__main__': def main():
"main function"
# optional command line args # optional command line args
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--train", help="train the NN", action="store_true") parser.add_argument("--train", help="train the NN", action="store_true")
...@@ -132,15 +96,29 @@ if __name__ == '__main__': ...@@ -132,15 +96,29 @@ if __name__ == '__main__':
parser.add_argument("--beamsearch", help="use beam search instead of best path decoding", action="store_true") parser.add_argument("--beamsearch", help="use beam search instead of best path decoding", action="store_true")
args = parser.parse_args() args = parser.parse_args()
# use beam search (better accuracy, but slower) instead of best path decoding # train or validate on IAM dataset
if args.beamsearch: if args.train or args.validate:
useBeamSearch = True # load training data, create TF model
loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen)
# save characters of model for inference mode
open(FilePaths.fnCharList, 'w').write(str().join(loader.charList))
# train or validate NN, or infer text on the text image # execute training or validation
if args.train: if args.train:
train(fnTrain) model = Model(loader.charList, args.beamsearch)
train(model, loader)
elif args.validate: elif args.validate:
validate(fnTrain) model = Model(loader.charList, args.beamsearch, mustRestore=True)
validate(model, loader)
# infer text on test image
else: else:
infer(fnInfer) print(open(FilePaths.fnAccuracy).read())
model = Model(open(FilePaths.fnCharList).read(), args.beamsearch, mustRestore=True)
infer(model, FilePaths.fnInfer)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment