Skip to content
Snippets Groups Projects
Commit bea05f67 authored by Laura Christine Kühle
Browse files

Added documentation to 'ANN_Training'.

parent 1d1344ca
No related branches found
No related tags found
No related merge requests found
......@@ -5,13 +5,14 @@
Code-Style: E226, W503
Docstring-Style: D200, D400
TODO: Add documentation
TODO: Add documentation -> Done
TODO: Add README for ANN training
TODO: Fix random seed
TODO: Write-protect all data and models
TODO: Put legend outside plot (bbox_to_anchor)
TODO: Put plotting into separate function
TODO: Reduce number of testing epochs to 50
TODO: Adapt docstring to uniform standard
"""
import numpy as np
......@@ -32,10 +33,51 @@ matplotlib.use('Agg')
class ModelTrainer(object):
def __init__(self, config):
"""Class for ANN model training.
Trains and tests a model with set loss function and optimizer.
Attributes
----------
model : torch.nn.Module
ANN model instance for evaluation.
loss_function : torch.nn.modules.loss
Function to evaluate loss during model training.
optimizer : torch.optim
Optimizer for model training.
validation_loss : torch.Tensor
Tensor of validation loss values recorded during training.
Methods
-------
epoch_training()
Trains model for a given number of epochs.
test_model()
Evaluates predictions of a model.
save_model()
Saves state and validation loss of a model.
"""
def __init__(self, config: dict) -> None:
"""Initializes ModelTrainer.
Parameters
----------
config : dict
Additional parameters for model trainer.
"""
self._reset(config)
def _reset(self, config):
def _reset(self, config: dict) -> None:
"""Resets instance variables.
Parameters
----------
config : dict
Additional parameters for model trainer.
"""
self._batch_size = config.pop('batch_size', 500)
self._num_epochs = config.pop('num_epochs', 1000)
self._threshold = config.pop('threshold', 1e-5)
......@@ -64,8 +106,27 @@ class ModelTrainer(object):
self._optimizer = getattr(torch.optim, optimizer)(
self._model.parameters(), **optimizer_config)
self._validation_loss = torch.zeros(self._num_epochs//10)
print(type(self._model), type(self._loss_function), type(self._optimizer),
type(self._validation_loss))
def epoch_training(self, dataset: torch.utils.data.dataset.TensorDataset,
num_epochs: int = None, verbose: bool = True) -> None:
"""Trains model for a given number of epochs.
Trains model and saves the validation loss. The training stops after the given number of
epochs or if the threshold is reached.
Parameters
----------
dataset : torch.utils.data.dataset.TensorDataset
Training dataset.
num_epochs : int, optional
Number of epochs for training. If None, set to instance value. Default: None.
verbose : bool, optional
Flag whether commentary in console is wanted. Default: True.
def epoch_training(self, dataset, num_epochs=None, verbose=True):
"""
print(type(dataset))
tic = time.perf_counter()
if num_epochs is None:
num_epochs = self._num_epochs
......@@ -117,7 +178,26 @@ class ModelTrainer(object):
if verbose:
print(f'Total runtime: {toc-tic:0.4f}s\n')
def test_model(self, training_set, test_set):
def test_model(self, training_set: torch.utils.data.dataset.TensorDataset,
test_set: torch.utils.data.dataset.TensorDataset) -> dict:
"""Evaluates predictions of a model.
Trains a model and compares the predicted and true results by evaluating precision, recall,
and f-score for both classes, as well as accuracy and AUROC score.
Parameters
----------
training_set : torch.utils.data.dataset.TensorDataset
Training dataset.
test_set : torch.utils.data.dataset.TensorDataset
Test dataset.
Returns
-------
dict
Dictionary containing classification evaluation data.
"""
self.epoch_training(training_set, num_epochs=100, verbose=False)
self._model.eval()
......@@ -137,7 +217,17 @@ class ModelTrainer(object):
'F-Score_Smooth': f_score[0], 'F-Score_Troubled': f_score[1],
'Accuracy': accuracy, 'AUROC': auroc}
def save_model(self, directory, model_name='test_model'):
def save_model(self, directory: str, model_name: str = 'test_model') -> None:
"""Saves state and validation loss of a model.
Parameters
----------
directory: str
Path to directory in which model is saved.
model_name: str, optional
Name of model for saving. Default: 'test_model'.
"""
# Set paths for files if not existing already
model_dir = directory + '/trained models'
if not os.path.exists(model_dir):
......@@ -148,15 +238,47 @@ class ModelTrainer(object):
torch.save(self._validation_loss, model_dir + '/loss__' + model_name + '.pt')
def read_training_data(directory, normalized=True):
def read_training_data(directory: str,
                       normalized: bool = True) -> torch.utils.data.dataset.TensorDataset:
    """Reads training data from directory.

    Loads the saved input and output arrays and combines them into a dataset.
    The input is read from 'normalized_input_data.npy' if normalized data is
    requested and from 'input_data.npy' otherwise. The output is always read
    from 'output_data.npy'.

    Parameters
    ----------
    directory : str
        Path to directory in which training data is saved. Path-like objects
        are accepted as well.
    normalized : bool, optional
        Flag whether normalized data should be used. Default: True.

    Returns
    -------
    torch.utils.data.dataset.TensorDataset
        Training dataset.

    """
    # Join path components instead of concatenating strings: this also works
    # for path-like directories and does not silently point to the filesystem
    # root if the directory is an empty string.
    input_name = 'normalized_input_data.npy' if normalized else 'input_data.npy'
    input_file = os.path.join(directory, input_name)
    output_file = os.path.join(directory, 'output_data.npy')

    # Get training dataset from saved files and map to Torch tensor and dataset
    return TensorDataset(*map(torch.tensor, (np.load(input_file), np.load(output_file))))
def evaluate_models(models, directory, num_iterations=100, colors=None,
compare_normalization=False):
def evaluate_models(models: dict, directory: str, num_iterations: int = 100, colors: dict = None,
compare_normalization: bool = False) -> None:
"""Evaluates the classification of a given set of models.
Parameters
----------
models: dict
Dictionary of models to evaluate.
directory: str
Path to directory for saving resulting plots.
num_iterations: int, optional
Number of iterations for evaluation. Default: 100.
colors: dict, optional
Dictionary containing plotting colors. If None, set to default colors. Default: None.
compare_normalization: bool, optional
Flag whether both normalized and raw data should be evaluated. Default: False.
"""
tic = time.perf_counter()
if colors is None:
colors = {'Accuracy': 'magenta', 'Precision_Smooth': 'red',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment