diff --git a/src/main.py b/src/main.py index d28dacca3971688089c02472b8139bce51baac2f..21024842707aa398cebaa9da329091ea57642f23 100644 --- a/src/main.py +++ b/src/main.py @@ -18,15 +18,16 @@ class FilePaths: fnCorpus = '../data/corpus.txt' -def write_summary(charErrorRates): +def write_summary(charErrorRates, wordAccuracies): with open(FilePaths.fnSummary, 'w') as f: - json.dump(charErrorRates, f) + json.dump({'charErrorRates': charErrorRates, 'wordAccuracies': wordAccuracies}, f) def train(model, loader): "train NN" epoch = 0 # number of training epochs since start summaryCharErrorRates = [] + summaryWordAccuracies = [] bestCharErrorRate = float('inf') # best valdiation character error rate noImprovementSince = 0 # number of epochs no improvement of character error rate occured earlyStopping = 25 # stop training after this number of epochs without improvement @@ -44,11 +45,12 @@ def train(model, loader): print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}') # validate - charErrorRate = validate(model, loader) + charErrorRate, wordAccuracy = validate(model, loader) # write summary summaryCharErrorRates.append(charErrorRate) - write_summary(summaryCharErrorRates) + summaryWordAccuracies.append(wordAccuracy) + write_summary(summaryCharErrorRates, summaryWordAccuracies) # if best validation accuracy so far, save model parameters if charErrorRate < bestCharErrorRate: @@ -94,7 +96,7 @@ def validate(model, loader): charErrorRate = numCharErr / numCharTotal wordAccuracy = numWordOK / numWordTotal print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.') - return charErrorRate + return charErrorRate, wordAccuracy def infer(model, fnImg):