Skip to content
Snippets Groups Projects
Commit 2b2a90a2 authored by Alexander Alexeev's avatar Alexander Alexeev
Browse files

Fixed Bugs and Added Train Loss Dump Feature

parent 2c14ec8e
No related branches found
No related tags found
No related merge requests found
...@@ -30,10 +30,10 @@ def get_img_size(line_mode: bool = False) -> Tuple[int, int]: ...@@ -30,10 +30,10 @@ def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
return 128, get_img_height() return 128, get_img_height()
def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -> None: def write_summary(average_train_loss: List[float], char_error_rates: List[float], word_accuracies: List[float]) -> None:
"""Writes training summary file for NN.""" """Writes training summary file for NN."""
with open(FilePaths.fn_summary, 'w') as f: with open(FilePaths.fn_summary, 'w') as f:
json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f) json.dump({'averageTrainLoss': average_train_loss, 'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f)
def char_list_from_file() -> List[str]: def char_list_from_file() -> List[str]:
...@@ -49,6 +49,10 @@ def train(model: Model, ...@@ -49,6 +49,10 @@ def train(model: Model,
epoch = 0 # number of training epochs since start epoch = 0 # number of training epochs since start
summary_char_error_rates = [] summary_char_error_rates = []
summary_word_accuracies = [] summary_word_accuracies = []
train_loss_in_epoch = []
average_train_loss = []
preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode) preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode)
best_char_error_rate = float('inf') # best validation character error rate best_char_error_rate = float('inf') # best validation character error rate
no_improvement_since = 0 # number of epochs no improvement of character error rate occurred no_improvement_since = 0 # number of epochs no improvement of character error rate occurred
...@@ -66,6 +70,7 @@ def train(model: Model, ...@@ -66,6 +70,7 @@ def train(model: Model,
batch = preprocessor.process_batch(batch) batch = preprocessor.process_batch(batch)
loss = model.train_batch(batch) loss = model.train_batch(batch)
print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}') print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}')
train_loss_in_epoch.append(loss)
# validate # validate
char_error_rate, word_accuracy = validate(model, loader, line_mode) char_error_rate, word_accuracy = validate(model, loader, line_mode)
...@@ -73,7 +78,11 @@ def train(model: Model, ...@@ -73,7 +78,11 @@ def train(model: Model,
# write summary # write summary
summary_char_error_rates.append(char_error_rate) summary_char_error_rates.append(char_error_rate)
summary_word_accuracies.append(word_accuracy) summary_word_accuracies.append(word_accuracy)
write_summary(summary_char_error_rates, summary_word_accuracies) average_train_loss.append((sum(train_loss_in_epoch)) / len(train_loss_in_epoch))
write_summary(average_train_loss, summary_char_error_rates, summary_word_accuracies)
# reset train loss list
train_loss_in_epoch = []
# if best validation accuracy so far, save model parameters # if best validation accuracy so far, save model parameters
if char_error_rate < best_char_error_rate: if char_error_rate < best_char_error_rate:
...@@ -82,12 +91,12 @@ def train(model: Model, ...@@ -82,12 +91,12 @@ def train(model: Model,
no_improvement_since = 0 no_improvement_since = 0
model.save() model.save()
else: else:
print(f'Character error rate not improved, best so far: {char_error_rate * 100.0}%') print(f'Character error rate not improved, best so far: {best_char_error_rate * 100.0}%')
no_improvement_since += 1 no_improvement_since += 1
# stop training if no more improvement in the last x epochs # stop training if no more improvement in the last x epochs
if no_improvement_since >= early_stopping: if no_improvement_since >= early_stopping:
print(f'No more improvement since {early_stopping} epochs. Training stopped.') print(f'No more improvement for {early_stopping} epochs. Training stopped.')
break break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment