Commit 998cc601 authored by Harald Scheidl

decaying learning rate, changed early stopping parameter

parent 0f0947b8
@@ -28,7 +28,9 @@ class Model:
         (self.loss, self.decoder) = self.setupCTC(rnnOut3d)

         # optimizer for NN parameters
-        self.optimizer = tf.train.RMSPropOptimizer(0.001).minimize(self.loss)
+        self.batchesTrained = 0
+        self.learningRate = tf.placeholder(tf.float32, shape=[])
+        self.optimizer = tf.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)

         # initialize TF
         (self.sess, self.saver) = self.setupTF()
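
The first hunk turns the learning rate from a constant baked into the optimizer into a tf.placeholder, so a fresh value can be fed on every training step. A minimal, self-contained TF 1.x sketch of that pattern (separate from the repo code; the toy variable and loss here are illustrative only):

import tensorflow as tf  # TF 1.x API, as used in the diff

# Toy graph: the optimizer is built once, but the learning rate is a
# placeholder, so each sess.run call can feed a different value.
x = tf.Variable(5.0)
loss = tf.square(x)
learningRate = tf.placeholder(tf.float32, shape=[])
optimizer = tf.train.RMSPropOptimizer(learningRate).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(3):
        rate = 0.01 if step < 1 else 0.001  # caller picks the rate per step
        _, lossVal = sess.run([optimizer, loss], {learningRate: rate})
        print(step, rate, lossVal)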
@@ -162,7 +164,9 @@ class Model:
     def trainBatch(self, batch):
         "feed a batch into the NN to train it"
         sparse = self.toSparse(batch.gtTexts)
-        (_, lossVal) = self.sess.run([self.optimizer, self.loss], {self.inputImgs : batch.imgs, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * Model.batchSize})
+        rate = 0.01 if self.batchesTrained < 10 else (0.001 if self.batchesTrained < 10000 else 0.0001) # decay learning rate
+        (_, lossVal) = self.sess.run([self.optimizer, self.loss], {self.inputImgs : batch.imgs, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * Model.batchSize, self.learningRate : rate})
+        self.batchesTrained += 1
         return lossVal
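
The rate picked in trainBatch is a simple three-stage step decay keyed on the number of batches trained so far. Pulled out into a standalone helper for illustration (the helper name is hypothetical; the thresholds and rates are the ones used above):

def decayedLearningRate(batchesTrained):
    # Three-stage step decay, as in trainBatch above.
    if batchesTrained < 10:      # brief high-rate warm-up
        return 0.01
    if batchesTrained < 10000:   # main training phase
        return 0.001
    return 0.0001                # late fine-tuning

assert decayedLearningRate(0) == 0.01
assert decayedLearningRate(100) == 0.001
assert decayedLearningRate(20000) == 0.0001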
@@ -20,7 +20,7 @@ def train(model, loader):
     epoch = 0 # number of training epochs since start
     bestCharErrorRate = float('inf') # best validation character error rate
     noImprovementSince = 0 # number of epochs with no improvement of character error rate
-    earlyStopping = 3 # stop training after this number of epochs without improvement
+    earlyStopping = 5 # stop training after this number of epochs without improvement
     while True:
         epoch += 1
         print('Epoch:', epoch)
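
The last hunk only raises the patience constant from 3 to 5 epochs. For context, a minimal sketch of how such a counter typically drives early stopping; the rest of the loop is not part of this diff, and train_one_epoch and validate are hypothetical helpers standing in for the repo's training and validation steps:

bestCharErrorRate = float('inf')
noImprovementSince = 0
earlyStopping = 5  # patience: epochs without improvement before stopping

while True:
    train_one_epoch(model, loader)            # hypothetical helper
    charErrorRate = validate(model, loader)   # hypothetical helper
    if charErrorRate < bestCharErrorRate:
        bestCharErrorRate = charErrorRate     # new best: reset the counter
        noImprovementSince = 0
    else:
        noImprovementSince += 1
    if noImprovementSince >= earlyStopping:   # patience exhausted
        break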