diff --git a/src/main.py b/src/main.py index 420461e588d71d998c6c1fc6fa7967c7fb10f9f0..1ee11d29dc923b7958ed9ed7cbb80a527c74019a 100644 --- a/src/main.py +++ b/src/main.py @@ -30,10 +30,10 @@ def get_img_size(line_mode: bool = False) -> Tuple[int, int]: return 128, get_img_height() -def write_summary(char_error_rates: List[float], word_accuracies: List[float]) -> None: +def write_summary(average_train_loss: List[float], char_error_rates: List[float], word_accuracies: List[float]) -> None: """Writes training summary file for NN.""" with open(FilePaths.fn_summary, 'w') as f: - json.dump({'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f) + json.dump({'averageTrainLoss': average_train_loss, 'charErrorRates': char_error_rates, 'wordAccuracies': word_accuracies}, f) def char_list_from_file() -> List[str]: @@ -49,6 +49,10 @@ def train(model: Model, epoch = 0 # number of training epochs since start summary_char_error_rates = [] summary_word_accuracies = [] + + train_loss_in_epoch = [] + average_train_loss = [] + preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode) best_char_error_rate = float('inf') # best validation character error rate no_improvement_since = 0 # number of epochs no improvement of character error rate occurred @@ -66,6 +70,7 @@ def train(model: Model, batch = preprocessor.process_batch(batch) loss = model.train_batch(batch) print(f'Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}') + train_loss_in_epoch.append(loss) # validate char_error_rate, word_accuracy = validate(model, loader, line_mode) @@ -73,7 +78,11 @@ def train(model: Model, # write summary summary_char_error_rates.append(char_error_rate) summary_word_accuracies.append(word_accuracy) - write_summary(summary_char_error_rates, summary_word_accuracies) + average_train_loss.append((sum(train_loss_in_epoch)) / len(train_loss_in_epoch)) + write_summary(average_train_loss, summary_char_error_rates, summary_word_accuracies) + + # reset train loss list + train_loss_in_epoch = [] # if best validation accuracy so far, save model parameters if char_error_rate < best_char_error_rate: @@ -82,12 +91,12 @@ def train(model: Model, no_improvement_since = 0 model.save() else: - print(f'Character error rate not improved, best so far: {char_error_rate * 100.0}%') + print(f'Character error rate not improved, best so far: {best_char_error_rate * 100.0}%') no_improvement_since += 1 # stop training if no more improvement in the last x epochs if no_improvement_since >= early_stopping: - print(f'No more improvement since {early_stopping} epochs. Training stopped.') + print(f'No more improvement for {early_stopping} epochs. Training stopped.') break