diff --git a/ANN_Training.py b/ANN_Training.py index c22dbe88cbac53a8c3b2c0e35ea02d91feb846e9..67f11cbbd4ba8c61806c442a1ed38b2274d6941a 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -5,13 +5,14 @@ Code-Style: E226, W503 Docstring-Style: D200, D400 -TODO: Add documentation +TODO: Add documentation -> Done TODO: Add README for ANN training TODO: Fix random seed TODO: Write-protect all data and models TODO: Put legend outside plot (bbox_to_anchor) TODO: Put plotting into separate function TODO: Reduce number of testing epochs to 50 +TODO: Adapt docstring to uniform standard """ import numpy as np @@ -32,10 +33,51 @@ matplotlib.use('Agg') class ModelTrainer(object): - def __init__(self, config): + """Class for ANN model training. + + Trains and tests a model with set loss function and optimizer. + + Attributes + ---------- + model : torch.nn.Module + ANN model instance for evaluation. + loss_function : torch.nn.modules.loss + Function to evaluate loss during model training. + optimizer : torch.optim + Optimizer for model training. + validation_loss : torch.Tensor + List of validation loss values during training. + + Methods + ------- + epoch_training() + Trains model for a given number of epochs. + test_model() + Evaluates predictions of a model. + save_model() + Saves state and validation loss of a model. + + """ + def __init__(self, config: dict) -> None: + """Initializes ModelTrainer. + + Parameters + ---------- + config : dict + Additional parameters for model trainer. + + """ self._reset(config) - def _reset(self, config): + def _reset(self, config: dict) -> None: + """Resets instance variables. + + Parameters + ---------- + config : dict + Additional parameters for model trainer. 
+ + """ self._batch_size = config.pop('batch_size', 500) self._num_epochs = config.pop('num_epochs', 1000) self._threshold = config.pop('threshold', 1e-5) @@ -64,8 +106,27 @@ class ModelTrainer(object): self._optimizer = getattr(torch.optim, optimizer)( self._model.parameters(), **optimizer_config) self._validation_loss = torch.zeros(self._num_epochs//10) - - def epoch_training(self, dataset, num_epochs=None, verbose=True): + print(type(self._model), type(self._loss_function), type(self._optimizer), + type(self._validation_loss)) + + def epoch_training(self, dataset: torch.utils.data.dataset.TensorDataset, + num_epochs: int = None, verbose: bool = True) -> None: + """Trains model for a given number of epochs. + + Trains model and saves the validation loss. The training stops after the given number of + epochs or if the threshold is reached. + + Parameters + ---------- + dataset : torch.utils.data.dataset.TensorDataset + Training dataset. + num_epochs : int, optional + Number of epochs for training. If None, set to instance value. Default: None. + verbose : bool, optional + Flag whether commentary in console is wanted. Default: True. + + """ + print(type(dataset)) tic = time.perf_counter() if num_epochs is None: num_epochs = self._num_epochs @@ -117,7 +178,26 @@ class ModelTrainer(object): if verbose: print(f'Total runtime: {toc-tic:0.4f}s\n') - def test_model(self, training_set, test_set): + def test_model(self, training_set: torch.utils.data.dataset.TensorDataset, + test_set: torch.utils.data.dataset.TensorDataset) -> dict: + """Evaluates predictions of a model. + + Trains a model and compares the predicted and true results by evaluating precision, recall, + and f-score for both classes, as well as accuracy and AUROC score. + + Parameters + ---------- + training_set : torch.utils.data.dataset.TensorDataset + Training dataset. + test_set : torch.utils.data.dataset.TensorDataset + Test dataset. 
+ + Returns + ------- + dict + Dictionary containing classification evaluation data. + + """ self.epoch_training(training_set, num_epochs=100, verbose=False) self._model.eval() @@ -137,7 +217,17 @@ class ModelTrainer(object): 'F-Score_Smooth': f_score[0], 'F-Score_Troubled': f_score[1], 'Accuracy': accuracy, 'AUROC': auroc} - def save_model(self, directory, model_name='test_model'): + def save_model(self, directory: str, model_name: str = 'test_model') -> None: + """Saves state and validation loss of a model. + + Parameters + ---------- + directory : str + Path to directory in which model is saved. + model_name : str, optional + Name of model for saving. Default: 'test_model'. + + """ # Set paths for files if not existing already model_dir = directory + '/trained models' if not os.path.exists(model_dir): @@ -148,15 +238,47 @@ class ModelTrainer(object): torch.save(self._validation_loss, model_dir + '/loss__' + model_name + '.pt') -def read_training_data(directory, normalized=True): +def read_training_data(directory: str, + normalized: bool = True) -> torch.utils.data.dataset.TensorDataset: + """Reads training data from directory. + + Parameters + ---------- + directory : str + Path to directory in which training data is saved. + normalized : bool, optional + Flag whether normalized data should be used. Default: True. + + Returns + ------- + torch.utils.data.dataset.TensorDataset + Training dataset. 
+ + """ # Get training dataset from saved file and map to Torch tensor and dataset input_file = directory + ('/normalized_input_data.npy' if normalized else '/input_data.npy') output_file = directory + '/output_data.npy' return TensorDataset(*map(torch.tensor, (np.load(input_file), np.load(output_file)))) -def evaluate_models(models, directory, num_iterations=100, colors=None, - compare_normalization=False): +def evaluate_models(models: dict, directory: str, num_iterations: int = 100, colors: dict = None, + compare_normalization: bool = False) -> None: + """Evaluates the classification of a given set of models. + + Parameters + ---------- + models : dict + Dictionary of models to evaluate. + directory : str + Path to directory for saving resulting plots. + num_iterations : int, optional + Number of iterations for evaluation. Default: 100. + colors : dict, optional + Dictionary containing plotting colors. If None, set to default colors. Default: None. + compare_normalization : bool, optional + Flag whether both normalized and raw data should be evaluated. Default: False. + + """ tic = time.perf_counter() if colors is None: colors = {'Accuracy': 'magenta', 'Precision_Smooth': 'red',