diff --git a/src/Model.py b/src/Model.py index 1d12c81aaa490f34c3a49d9bdb7bf954ffed5643..e987acf2b0ae2a436e4b6dda8ec69f6b66b6910f 100644 --- a/src/Model.py +++ b/src/Model.py @@ -260,14 +260,14 @@ class Model: ctcInput = evalRes[1] evalList = self.lossPerElement feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False} - #lossVals = self.sess.run(evalList, feedDict) - #probs = np.exp(-lossVals) + lossVals = self.sess.run(evalList, feedDict) + probs = np.exp(-lossVals) # dump the output of the NN to CSV file(s) if self.dump: self.dumpNNOutput(evalRes[1]) - return (texts) + return (texts, probs) def save(self): diff --git a/src/main.py b/src/main.py index 764532068023b035920293900b09db68c6652303..1e712830526d579f001113c4e4a2b3a6de8318ca 100644 --- a/src/main.py +++ b/src/main.py @@ -40,7 +40,7 @@ def train(model, loader): # validate charErrorRate = validate(model, loader) - + # if best validation accuracy so far, save model parameters if charErrorRate < bestCharErrorRate: print('Character error rate improved, save model') @@ -71,8 +71,8 @@ def validate(model, loader): print('Batch:', iterInfo[0],'/', iterInfo[1]) batch = loader.getNext() (recognized, _) = model.inferBatch(batch) - - print('Ground truth -> Recognized') + + print('Ground truth -> Recognized') for i in range(len(recognized)): numWordOK += 1 if batch.gtTexts[i] == recognized[i] else 0 numWordTotal += 1 @@ -80,7 +80,7 @@ def validate(model, loader): numCharErr += dist numCharTotal += len(batch.gtTexts[i]) print('[OK]' if dist==0 else '[ERR:%d]' % dist,'"' + batch.gtTexts[i] + '"', '->', '"' + recognized[i] + '"') - + # print validation result charErrorRate = numCharErr / numCharTotal wordAccuracy = numWordOK / numWordTotal @@ -115,14 +115,14 @@ def main(): elif args.wordbeamsearch: decoderType = DecoderType.WordBeamSearch - # train or validate on IAM dataset + # train or validate on IAM dataset if args.train or args.validate: # load training data, create TF model loader = DataLoader(FilePaths.fnTrain, Model.batchSize, Model.imgSize, Model.maxTextLen) # save characters of model for inference mode open(FilePaths.fnCharList, 'w').write(str().join(loader.charList)) - + # save words contained in dataset into file open(FilePaths.fnCorpus, 'w').write(str(' ').join(loader.trainWords + loader.validationWords)) @@ -143,4 +143,3 @@ def main(): if __name__ == '__main__': main() -