Skip to content
Snippets Groups Projects
Commit 2b2a90a2 authored by Alexander Alexeev's avatar Alexander Alexeev
Browse files

Fixed Bugs and Added Train Loss Dump Feature

parent 2c14ec8e
Branches
No related tags found
No related merge requests found
......@@ -30,10 +30,10 @@ def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
return 128, get_img_height()
def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -> None:
def write_summary(average_train_loss: List[float], char_error_rates: List[float], word_accuracies: List[float]) -> None:
"""Writes training summary file for NN."""
with open(FilePaths.fn_summary, 'w') as f:
json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
json.dump({'averageTrainLoss': average_train_loss, 'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
def char_list_from_file() -> List[str]:
......@@ -49,6 +49,10 @@ def train(model: Model,
epoch = 0 # number of training epochs since start
summary_char_error_rates = []
summary_word_accuracies = []
train_loss_in_epoch = []
average_train_loss = []
preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode)
best_char_error_rate = float('inf') # best validation character error rate
no_improvement_since = 0 # number of epochs no improvement of character error rate occurred
......@@ -66,6 +70,7 @@ def train(model: Model,
batch = preprocessor.process_batch(batch)
loss = model.train_batch(batch)
print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}')
train_loss_in_epoch.append(loss)
# validate
char_error_rate, word_accuracy = validate(model, loader, line_mode)
......@@ -73,7 +78,11 @@ def train(model: Model,
# write summary
summary_char_error_rates.append(char_error_rate)
summary_word_accuracies.append(word_accuracy)
write_summary(summary_char_error_rates, summary_word_accuracies)
average_train_loss.append((sum(train_loss_in_epoch)) / len(train_loss_in_epoch))
write_summary(average_train_loss, summary_char_error_rates, summary_word_accuracies)
# reset train loss list
train_loss_in_epoch = []
# if best validation accuracy so far, save model parameters
if char_error_rate < best_char_error_rate:
......@@ -82,12 +91,12 @@ def train(model: Model,
no_improvement_since = 0
model.save()
else:
print(f'Character error rate not improved, best so far: {char_error_rate * 100.0}%')
print(f'Character error rate not improved, best so far: {best_char_error_rate * 100.0}%')
no_improvement_since += 1
# stop training if no more improvement in the last x epochs
if no_improvement_since >= early_stopping:
print(f'No more improvement since {early_stopping} epochs. Training stopped.')
print(f'No more improvement for {early_stopping} epochs. Training stopped.')
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment