Skip to content
Snippets Groups Projects
Commit 69410d8a authored by Nishant
Browse files

Migrated from TensorFlow v1 to TensorFlow v2.

parent 97c2512f
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,8 @@ import numpy as np ...@@ -6,6 +6,8 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import os import os
# Disable eager execution: the model uses TF1-style placeholders and sessions
tf.compat.v1.disable_eager_execution()
class DecoderType: class DecoderType:
BestPath = 0 BestPath = 0
...@@ -17,7 +19,7 @@ class Model: ...@@ -17,7 +19,7 @@ class Model:
"minimalistic TF model for HTR" "minimalistic TF model for HTR"
# model constants # model constants
batchSize = 50 batchSize = 32
imgSize = (128, 32) imgSize = (128, 32)
maxTextLen = 32 maxTextLen = 32
...@@ -30,10 +32,10 @@ class Model: ...@@ -30,10 +32,10 @@ class Model:
self.snapID = 0 self.snapID = 0
# Whether to use normalization over a batch or a population # Whether to use normalization over a batch or a population
self.is_train = tf.placeholder(tf.bool, name='is_train') self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')
# input image batch # input image batch
self.inputImgs = tf.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1])) self.inputImgs = tf.compat.v1.placeholder(tf.float32, shape=(None, Model.imgSize[0], Model.imgSize[1]))
# setup CNN, RNN and CTC # setup CNN, RNN and CTC
self.setupCNN() self.setupCNN()
...@@ -42,10 +44,10 @@ class Model: ...@@ -42,10 +44,10 @@ class Model:
# setup optimizer to train NN # setup optimizer to train NN
self.batchesTrained = 0 self.batchesTrained = 0
self.learningRate = tf.placeholder(tf.float32, shape=[]) self.learningRate = tf.compat.v1.placeholder(tf.float32, shape=[])
self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(self.update_ops): with tf.control_dependencies(self.update_ops):
self.optimizer = tf.train.RMSPropOptimizer(self.learningRate).minimize(self.loss) self.optimizer = tf.compat.v1.train.RMSPropOptimizer(self.learningRate).minimize(self.loss)
# initialize TF # initialize TF
(self.sess, self.saver) = self.setupTF() (self.sess, self.saver) = self.setupTF()
...@@ -64,11 +66,11 @@ class Model: ...@@ -64,11 +66,11 @@ class Model:
# create layers # create layers
pool = cnnIn4d # input to first CNN layer pool = cnnIn4d # input to first CNN layer
for i in range(numLayers): for i in range(numLayers):
kernel = tf.Variable(tf.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1)) kernel = tf.Variable(tf.random.truncated_normal([kernelVals[i], kernelVals[i], featureVals[i], featureVals[i + 1]], stddev=0.1))
conv = tf.nn.conv2d(pool, kernel, padding='SAME', strides=(1,1,1,1)) conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1,1,1,1))
conv_norm = tf.layers.batch_normalization(conv, training=self.is_train) conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train)
relu = tf.nn.relu(conv_norm) relu = tf.nn.relu(conv_norm)
pool = tf.nn.max_pool(relu, (1, poolVals[i][0], poolVals[i][1], 1), (1, strideVals[i][0], strideVals[i][1], 1), 'VALID') pool = tf.nn.max_pool2d(input=relu, ksize=(1, poolVals[i][0], poolVals[i][1], 1), strides=(1, strideVals[i][0], strideVals[i][1], 1), padding='VALID')
self.cnnOut4d = pool self.cnnOut4d = pool
...@@ -79,43 +81,43 @@ class Model: ...@@ -79,43 +81,43 @@ class Model:
# basic cells which is used to build RNN # basic cells which is used to build RNN
numHidden = 256 numHidden = 256
cells = [tf.contrib.rnn.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in range(2)] # 2 layers cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=numHidden, state_is_tuple=True) for _ in range(2)] # 2 layers
# stack basic cells # stack basic cells
stacked = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True) stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
# bidirectional RNN # bidirectional RNN
# BxTxF -> BxTx2H # BxTxF -> BxTx2H
((fw, bw), _) = tf.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, dtype=rnnIn3d.dtype) ((fw, bw), _) = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d, dtype=rnnIn3d.dtype)
# BxTxH + BxTxH -> BxTx2H -> BxTx1X2H # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
concat = tf.expand_dims(tf.concat([fw, bw], 2), 2) concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)
# project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
kernel = tf.Variable(tf.truncated_normal([1, 1, numHidden * 2, len(self.charList) + 1], stddev=0.1)) kernel = tf.Variable(tf.random.truncated_normal([1, 1, numHidden * 2, len(self.charList) + 1], stddev=0.1))
self.rnnOut3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2]) self.rnnOut3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'), axis=[2])
def setupCTC(self): def setupCTC(self):
"create CTC loss and decoder and return them" "create CTC loss and decoder and return them"
# BxTxC -> TxBxC # BxTxC -> TxBxC
self.ctcIn3dTBC = tf.transpose(self.rnnOut3d, [1, 0, 2]) self.ctcIn3dTBC = tf.transpose(a=self.rnnOut3d, perm=[1, 0, 2])
# ground truth text as sparse tensor # ground truth text as sparse tensor
self.gtTexts = tf.SparseTensor(tf.placeholder(tf.int64, shape=[None, 2]) , tf.placeholder(tf.int32, [None]), tf.placeholder(tf.int64, [2])) self.gtTexts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]) , tf.compat.v1.placeholder(tf.int32, [None]), tf.compat.v1.placeholder(tf.int64, [2]))
# calc loss for batch # calc loss for batch
self.seqLen = tf.placeholder(tf.int32, [None]) self.seqLen = tf.compat.v1.placeholder(tf.int32, [None])
self.loss = tf.reduce_mean(tf.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True)) self.loss = tf.reduce_mean(input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, ctc_merge_repeated=True))
# calc loss for each element to compute label probability # calc loss for each element to compute label probability
self.savedCtcInput = tf.placeholder(tf.float32, shape=[Model.maxTextLen, None, len(self.charList) + 1]) self.savedCtcInput = tf.compat.v1.placeholder(tf.float32, shape=[Model.maxTextLen, None, len(self.charList) + 1])
self.lossPerElement = tf.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput, sequence_length=self.seqLen, ctc_merge_repeated=True) self.lossPerElement = tf.compat.v1.nn.ctc_loss(labels=self.gtTexts, inputs=self.savedCtcInput, sequence_length=self.seqLen, ctc_merge_repeated=True)
# decoder: either best path decoding or beam search decoding # decoder: either best path decoding or beam search decoding
if self.decoderType == DecoderType.BestPath: if self.decoderType == DecoderType.BestPath:
self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen) self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen)
elif self.decoderType == DecoderType.BeamSearch: elif self.decoderType == DecoderType.BeamSearch:
self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50, merge_repeated=False) self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctcIn3dTBC, sequence_length=self.seqLen, beam_width=50)
elif self.decoderType == DecoderType.WordBeamSearch: elif self.decoderType == DecoderType.WordBeamSearch:
# import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch) # import compiled word beam search operation (see https://github.com/githubharald/CTCWordBeamSearch)
word_beam_search_module = tf.load_op_library('TFWordBeamSearch.so') word_beam_search_module = tf.load_op_library('TFWordBeamSearch.so')
...@@ -126,7 +128,7 @@ class Model: ...@@ -126,7 +128,7 @@ class Model:
corpus = open('../data/corpus.txt').read() corpus = open('../data/corpus.txt').read()
# decode using the "Words" mode of word beam search # decode using the "Words" mode of word beam search
self.decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(self.ctcIn3dTBC, dim=2), 50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8')) self.decoder = word_beam_search_module.word_beam_search(tf.nn.softmax(self.ctcIn3dTBC, axis=2), 50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'), wordChars.encode('utf8'))
def setupTF(self): def setupTF(self):
...@@ -134,9 +136,9 @@ class Model: ...@@ -134,9 +136,9 @@ class Model:
print('Python: '+sys.version) print('Python: '+sys.version)
print('Tensorflow: '+tf.__version__) print('Tensorflow: '+tf.__version__)
sess=tf.Session() # TF session sess=tf.compat.v1.Session() # TF session
saver = tf.train.Saver(max_to_keep=1) # saver saves model to file saver = tf.compat.v1.train.Saver(max_to_keep=1) # saver saves model to file
modelDir = '../model/' modelDir = '../model/'
latestSnapshot = tf.train.latest_checkpoint(modelDir) # is there a saved model? latestSnapshot = tf.train.latest_checkpoint(modelDir) # is there a saved model?
...@@ -150,7 +152,7 @@ class Model: ...@@ -150,7 +152,7 @@ class Model:
saver.restore(sess, latestSnapshot) saver.restore(sess, latestSnapshot)
else: else:
print('Init with new values') print('Init with new values')
sess.run(tf.global_variables_initializer()) sess.run(tf.compat.v1.global_variables_initializer())
return (sess,saver) return (sess,saver)
...@@ -258,18 +260,17 @@ class Model: ...@@ -258,18 +260,17 @@ class Model:
ctcInput = evalRes[1] ctcInput = evalRes[1]
evalList = self.lossPerElement evalList = self.lossPerElement
feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False} feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
lossVals = self.sess.run(evalList, feedDict) #lossVals = self.sess.run(evalList, feedDict)
probs = np.exp(-lossVals) #probs = np.exp(-lossVals)
# dump the output of the NN to CSV file(s) # dump the output of the NN to CSV file(s)
if self.dump: if self.dump:
self.dumpNNOutput(evalRes[1]) self.dumpNNOutput(evalRes[1])
return (texts, probs) return (texts)
def save(self): def save(self):
"save model to file" "save model to file"
self.snapID += 1 self.snapID += 1
self.saver.save(self.sess, '../model/snapshot', global_step=self.snapID) self.saver.save(self.sess, '../model/snapshot', global_step=self.snapID)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment