Skip to content
Snippets Groups Projects
Commit 9441751e authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Improved naming for trained model.

parent 3ab204e3
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ TODO: Fix parameter bug -> Done
TODO: Improve naming of training data
TODO: Rename 'Artificial_Neural_Network' -> Done
TODO: Rename 'ANN_Training_Data_Generator' -> Done
TODO: Improve naming for trained model -> Done
"""
import numpy as np
......@@ -160,9 +161,10 @@ class ModelTrainer(object):
def save_model(self):
# Saving Model
path = self._model.get_name() + '__' + self._optimizer.__class__.__name__ + '_'\
+ str(self._learning_rate) + '__' + self._loss_function.__class__.__name__ + '__' + self._training_file\
+ '__' + self._validation_file + '.pt'
train_name = self._training_file.split('.npy')[0]
valid_name = self._validation_file.split('.npy')[0]
path = self._model.get_name() + '__' + self._optimizer.__class__.__name__ + '_' + str(self._learning_rate)\
+ '__' + self._loss_function.__class__.__name__ + '__' + train_name + '__' + valid_name + '.pt'
torch.save(self._model.state_dict(), self._model_dir + '/Model__' + path)
torch.save(self._validation_loss, self._model_dir + '/Loss__' + path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment