Skip to content
Snippets Groups Projects
Commit 00d7232d authored by Harald Scheidl's avatar Harald Scheidl
Browse files

log both word acc and cer for valset

parent 8c96a80e
Branches
No related tags found
No related merge requests found
...@@ -18,15 +18,16 @@ class FilePaths: ...@@ -18,15 +18,16 @@ class FilePaths:
fnCorpus = '../data/corpus.txt' fnCorpus = '../data/corpus.txt'
def write_summary(charErrorRates): def write_summary(charErrorRates, wordAccuracies):
with open(FilePaths.fnSummary, 'w') as f: with open(FilePaths.fnSummary, 'w') as f:
json.dump(charErrorRates, f) json.dump({'charErrorRates': charErrorRates, 'wordAccuracies': wordAccuracies}, f)
def train(model, loader): def train(model, loader):
"train NN" "train NN"
epoch = 0 # number of training epochs since start epoch = 0 # number of training epochs since start
summaryCharErrorRates = [] summaryCharErrorRates = []
summaryWordAccuracies = []
bestCharErrorRate = float('inf') # best valdiation character error rate bestCharErrorRate = float('inf') # best valdiation character error rate
noImprovementSince = 0 # number of epochs no improvement of character error rate occured noImprovementSince = 0 # number of epochs no improvement of character error rate occured
earlyStopping = 25 # stop training after this number of epochs without improvement earlyStopping = 25 # stop training after this number of epochs without improvement
...@@ -44,11 +45,12 @@ def train(model, loader): ...@@ -44,11 +45,12 @@ def train(model, loader):
print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}') print(f'Epoch: {epoch} Batch: {iterInfo[0]}/{iterInfo[1]} Loss: {loss}')
# validate # validate
charErrorRate = validate(model, loader) charErrorRate, wordAccuracy = validate(model, loader)
# write summary # write summary
summaryCharErrorRates.append(charErrorRate) summaryCharErrorRates.append(charErrorRate)
write_summary(summaryCharErrorRates) summaryWordAccuracies.append(wordAccuracy)
write_summary(summaryCharErrorRates, summaryWordAccuracies)
# if best validation accuracy so far, save model parameters # if best validation accuracy so far, save model parameters
if charErrorRate < bestCharErrorRate: if charErrorRate < bestCharErrorRate:
...@@ -94,7 +96,7 @@ def validate(model, loader): ...@@ -94,7 +96,7 @@ def validate(model, loader):
charErrorRate = numCharErr / numCharTotal charErrorRate = numCharErr / numCharTotal
wordAccuracy = numWordOK / numWordTotal wordAccuracy = numWordOK / numWordTotal
print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.') print(f'Character error rate: {charErrorRate * 100.0}%. Word accuracy: {wordAccuracy * 100.0}%.')
return charErrorRate return charErrorRate, wordAccuracy
def infer(model, fnImg): def infer(model, fnImg):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment