Skip to content
Snippets Groups Projects
Commit 10781ce4 authored by Harald Scheidl's avatar Harald Scheidl
Browse files

val set: always use all samples, even if last batch is smaller than specified batch size

parent 339cceed
No related branches found
No related tags found
No related merge requests found
......@@ -107,24 +107,34 @@ class DataLoaderIAM:
self.currIdx = 0
random.shuffle(self.trainSamples)
self.samples = self.trainSamples[:self.numTrainSamplesPerEpoch]
self.currSet = 'train'
def validationSet(self):
"switch to validation set"
self.dataAugmentation = False
self.currIdx = 0
self.samples = self.validationSamples
self.currSet = 'val'
def getIteratorInfo(self):
"current batch index and overall number of batches"
return (self.currIdx // self.batchSize + 1, len(self.samples) // self.batchSize)
if self.currSet == 'train':
numBatches = int(np.floor(len(self.samples) / self.batchSize)) # train set: only full-sized batches
else:
numBatches = int(np.ceil(len(self.samples) / self.batchSize)) # val set: allow last batch to be smaller
currBatch = self.currIdx // self.batchSize + 1
return currBatch, numBatches
def hasNext(self):
"iterator"
return self.currIdx + self.batchSize <= len(self.samples)
if self.currSet == 'train':
return self.currIdx + self.batchSize <= len(self.samples) # train set: only full-sized batches
else:
return self.currIdx < len(self.samples) # val set: allow last batch to be smaller
def getNext(self):
"iterator"
batchRange = range(self.currIdx, self.currIdx + self.batchSize)
batchRange = range(self.currIdx, min(self.currIdx + self.batchSize, len(self.samples)))
gtTexts = [self.samples[i].gtText for i in batchRange]
imgs = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment