From 92b4f31c0c261f1143c62f2b7ef40d54d91c3231 Mon Sep 17 00:00:00 2001 From: Nishant <getrooted0019@hotmail.com> Date: Tue, 8 Dec 2020 22:47:57 +0530 Subject: [PATCH] Fixed inferBatch ValueError. --- src/Model.py | 6 +++--- src/main.py | 13 ++++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/Model.py b/src/Model.py index 1d12c81..e987acf 100644 --- a/src/Model.py +++ b/src/Model.py @@ -260,14 +260,14 @@ class Model: ctcInput = evalRes[1] evalList = self.lossPerElement feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False} - #lossVals = self.sess.run(evalList, feedDict) - #probs = np.exp(-lossVals) + lossVals = self.sess.run(evalList, feedDict) + probs = np.exp(-lossVals) # dump the output of the NN to CSV file(s) if self.dump: self.dumpNNOutput(evalRes[1]) - return (texts) + return (texts, probs) def save(self): diff --git a/src/main.py b/src/main.py index 7645320..1e71283 100644 --- a/src/main.py +++ b/src/main.py @@ -40,7 +40,7 @@ def train(model, loader): # validate charErrorRate = validate(model, loader) - + # if best validation accuracy so far, save model parameters if charErrorRate < bestCharErrorRate: print('Character error rate improved, save model') @@ -71,8 +71,8 @@ def validate(model, loader): print('Batch:', iterInfo[0],'/', iterInfo[1]) batch = loader.getNext() (recognized, _) = model.inferBatch(batch) - - print('Ground truth -> Recognized') + + print('Ground truth -> Recognized') for i in range(len(recognized)): numWordOK += 1 if batch.gtTexts[i] == recognized[i] else 0 numWordTotal += 1 @@ -80,7 +80,7 @@ def validate(model, loader): numCharErr += dist numCharTotal += len(batch.gtTexts[i]) print('[OK]' if dist==0 else '[ERR:%d]' % dist,'"' + batch.gtTexts[i] + '"', '->', '"' + recognized[i] + '"') - + # print validation result charErrorRate = numCharErr / numCharTotal wordAccuracy = numWordOK / numWordTotal @@ -115,14 +115,14 @@ def main(): elif args.wordbeamsearch: decoderType = DecoderType.WordBeamSearch - # train or validate on IAM dataset + # train or validate on IAM dataset if args.train or args.validate: # load training data, create TF model loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen) # save characters of model for inference mode open(FilePaths.fnCharList, 'w').write(str().join(loader.charList)) - + # save words contained in dataset into file open(FilePaths.fnCorpus, 'w').write(str(' ').join(loader.trainWords + loader.validationWords)) @@ -143,4 +143,3 @@ def main(): if __name__ == '__main__': main() - -- GitLab