Commit e8249f7c authored by Harald Scheidl

added dump option for NN output

parent f2f606da
@@ -7,3 +7,4 @@ notes/
*.so
*.pyc
.idea/
dump/
\ No newline at end of file
@@ -39,6 +39,7 @@ Tested with:
* `--validate`: validate the NN, details see below.
* `--beamsearch`: use vanilla beam search decoding (better, but slower) instead of best path decoding.
* `--wordbeamsearch`: use word beam search decoding (only outputs words contained in a dictionary) instead of best path decoding. This is a custom TF operation and must be compiled from source; see the corresponding section below for more information. It should **not** be used when training the NN.
* `--dump`: dumps the output of the NN to CSV file(s) saved in the `dump/` folder. These files can be used as input for the [CTCDecoder](https://github.com/githubharald/CTCDecoder); see the example below the list.
If neither `--train` nor `--validate` is specified, the NN infers the text from the test image (`data/test.png`).
Two examples: to infer using beam search, execute `python main.py --beamsearch`; to train the NN and validate it using beam search, execute `python main.py --train --beamsearch`.
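As an example of the new flag (assuming a trained model is in place), executing `python main.py --dump` infers the text from `data/test.png` as before and additionally writes the NN output for batch element 0 to `dump/rnnOutput_0.csv` (one row per time step, with one `;`-separated value per character class), which can then be fed into the CTCDecoder.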
@@ -4,6 +4,7 @@ from __future__ import print_function
import sys
import numpy as np
import tensorflow as tf
import os
class DecoderType:
@@ -20,8 +21,9 @@ class Model:
	imgSize = (128, 32)
	maxTextLen = 32

	def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False):
	def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False, dump=False):
		"init model: add CNN, RNN and CTC and initialize TF"
		self.dump = dump
		self.charList = charList
		self.decoderType = decoderType
		self.mustRestore = mustRestore
@@ -217,14 +219,35 @@ class Model:
		return lossVal

	def dumpNNOutput(self, rnnOutput):
		"dump the output of the NN to CSV file(s)"
		dumpDir = '../dump/'
		if not os.path.isdir(dumpDir):
			os.mkdir(dumpDir)

		# iterate over all batch elements and create a CSV file for each one
		maxT, maxB, maxC = rnnOutput.shape
		for b in range(maxB):
			csv = ''
			for t in range(maxT):
				for c in range(maxC):
					csv += str(rnnOutput[t, b, c]) + ';'
				csv += '\n'
			fn = dumpDir + 'rnnOutput_'+str(b)+'.csv'
			print('Write dump of NN to file: ' + fn)
			with open(fn, 'w') as f:
				f.write(csv)
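	# A minimal sketch (not part of this commit) of how such a dumped CSV could be
	# read back into a (maxT, maxC) matrix, e.g. before feeding it to the CTCDecoder;
	# each row ends with a trailing ';', so the resulting empty last column is dropped:
	#
	#   import numpy as np
	#   mat = np.genfromtxt('../dump/rnnOutput_0.csv', delimiter=';')[:, :-1]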
	def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False):
		"feed a batch into the NN to recognize the texts"

		# decode, optionally save RNN output
		numBatchElements = len(batch.imgs)
		evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else [])
		evalRnnOutput = self.dump or calcProbability
		evalList = [self.decoder] + ([self.ctcIn3dTBC] if evalRnnOutput else [])
		feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
		evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict)
		evalRes = self.sess.run(evalList, feedDict)
		decoded = evalRes[0]
		texts = self.decoderOutputToText(decoded, numBatchElements)
@@ -237,6 +260,11 @@ class Model:
			feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
			lossVals = self.sess.run(evalList, feedDict)
			probs = np.exp(-lossVals)

		# dump the output of the NN to CSV file(s)
		if self.dump:
			self.dumpNNOutput(evalRes[1])

		return (texts, probs)
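	# Shape note added for clarity (not part of the original code): judging by its
	# name and by how dumpNNOutput unpacks it, evalRes[1] (self.ctcIn3dTBC) is the
	# time-major NN output of shape (maxT, batchSize, numClasses), so each dumped
	# CSV contains maxT rows with numClasses values per row.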
@@ -105,6 +105,8 @@ def main():
	parser.add_argument('--validate', help='validate the NN', action='store_true')
	parser.add_argument('--beamsearch', help='use beam search instead of best path decoding', action='store_true')
	parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding', action='store_true')
	parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')
	args = parser.parse_args()

	decoderType = DecoderType.BestPath
@@ -135,7 +137,7 @@ def main():
	# infer text on test image
	else:
		print(open(FilePaths.fnAccuracy).read())
		model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True)
		model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump)
		infer(model, FilePaths.fnInfer)