Skip to content
Snippets Groups Projects
Select Git revision
1 result Searching

introduction.tex

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    ANN_Training.py 12.46 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    Code-Style: E226, W503
    Docstring-Style: D200, D400
    
    TODO: Add README for ANN training
    
    """
    import numpy as np
    import time
    import os
    import torch
    import json
    from torch.utils.data import TensorDataset, DataLoader, random_split
    from sklearn.model_selection import KFold
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support, \
        roc_auc_score
    
    import ANN_Model
    
    
    class ModelTrainer:
        """Class for ANN model training.

        Trains and tests a model with set loss function and optimizer.

        Attributes
        ----------
        model : torch.nn.Module
            ANN model instance for evaluation.
        loss_function : torch.nn.modules.loss
            Function to evaluate loss during model training.
        optimizer : torch.optim
            Optimizer for model training.
        validation_loss : torch.Tensor
            Tensor of mean validation loss values, one entry written after
            every 100 completed epochs during training.

        Methods
        -------
        epoch_training(dataset, num_epochs, verbose)
            Trains model for a given number of epochs.
        test_model(training_set, test_set)
            Evaluates predictions of a model.
        save_model(directory, model_name)
            Saves state and validation loss of a model.

        """
        def __init__(self, config: dict) -> None:
            """Initializes ModelTrainer.

            Parameters
            ----------
            config : dict
                Additional parameters for model trainer.

            """
            self._reset(config)

        def _reset(self, config: dict) -> None:
            """Resets instance variables.

            The given dictionary is not modified; all values are read from
            a copy.

            Parameters
            ----------
            config : dict
                Additional parameters for model trainer. Recognized keys are
                'batch_size', 'num_epochs', 'threshold', 'model',
                'model_config', 'loss_function', 'loss_config', 'optimizer',
                'optimizer_config', and 'learning_rate'.

            Raises
            ------
            ValueError
                If the model, loss function, or optimizer name is unknown.

            """
            # Work on a copy so that the caller's dictionary is not emptied
            # and can be reused to configure further trainers
            config = dict(config)

            self._batch_size = config.pop('batch_size', 500)
            self._num_epochs = config.pop('num_epochs', 1000)
            self._threshold = config.pop('threshold', 1e-5)

            model = config.pop('model', 'ThreeLayerReLu')
            model_config = config.pop('model_config', {})
            loss_function = config.pop('loss_function', 'BCELoss')
            loss_config = config.pop('loss_config', {})
            optimizer = config.pop('optimizer', 'Adam')
            # Copied because the learning rate is written into it below
            optimizer_config = dict(config.pop('optimizer_config', {}))

            # Set learning rate (takes precedence over any 'lr' entry
            # in the optimizer configuration)
            learning_rate = config.pop('learning_rate', 1e-2)
            optimizer_config['lr'] = learning_rate

            if not hasattr(ANN_Model, model):
                raise ValueError(f'Invalid model: "{model}"')
            if not hasattr(torch.nn.modules.loss, loss_function):
                raise ValueError(f'Invalid loss function: "{loss_function}"')
            if not hasattr(torch.optim, optimizer):
                raise ValueError(f'Invalid optimizer: "{optimizer}"')

            self._model = getattr(ANN_Model, model)(model_config)
            self._loss_function = getattr(torch.nn.modules.loss, loss_function)(
                **loss_config)
            self._optimizer = getattr(torch.optim, optimizer)(
                self._model.parameters(), **optimizer_config)
            self._validation_loss = torch.zeros(self._num_epochs//10)

        def epoch_training(self, dataset: torch.utils.data.dataset.TensorDataset,
                           num_epochs: int = None, verbose: bool = True) -> None:
            """Trains model for a given number of epochs.

            Trains model and saves the validation loss. The training stops after
            the given number of epochs or if the threshold is reached.

            The dataset is split randomly into 80% training and 20% validation
            data. If the validation split is empty (very small datasets),
            the model is trained without validation, i.e. no loss is recorded
            and the threshold is never checked.

            Parameters
            ----------
            dataset : torch.utils.data.dataset.TensorDataset
                Training dataset.
            num_epochs : int, optional
                Number of epochs for training.
                Default: None (i.e. instance variable).
            verbose : bool, optional
                Flag whether commentary in console is wanted. Default: True.

            """
            tic = time.perf_counter()
            if num_epochs is None:
                num_epochs = self._num_epochs

            # Split data into training and validation set
            num_samples = len(dataset)
            if verbose:
                print('Splitting data randomly into training and validation set.')
            train_ds, valid_ds = random_split(dataset, [round(num_samples*0.8),
                                                        round(num_samples*0.2)])

            # Load sets
            train_dl = DataLoader(train_ds, batch_size=self._batch_size,
                                  shuffle=True)
            valid_dl = DataLoader(valid_ds, batch_size=self._batch_size * 2)
            num_valid_batches = len(valid_dl)

            # Train with validation
            if verbose:
                print('\nTraining model...')
                print('Number of epochs:', num_epochs)
            tic_train = time.perf_counter()
            for epoch in range(num_epochs):
                self._model.train()
                for x_batch, y_batch in train_dl:
                    pred = self._model(x_batch.float())
                    loss = self._loss_function(pred, y_batch.float()).mean()

                    # Run back propagation, update the weights,
                    # and zero gradients for next batch
                    loss.backward()
                    self._optimizer.step()
                    self._optimizer.zero_grad()

                self._model.eval()

                # Without validation batches there is no loss to average;
                # dividing by the number of batches would fail
                if num_valid_batches == 0:
                    continue

                # Determine validation loss (mean over validation batches)
                with torch.no_grad():
                    valid_loss = sum(
                        self._loss_function(self._model(x_batch_valid.float()),
                                            y_batch_valid.float())
                        for x_batch_valid, y_batch_valid in valid_dl)
                    mean_valid_loss = valid_loss / num_valid_batches

                    # Report validation loss
                    if (epoch+1) % 100 == 0:
                        self._validation_loss[(epoch+1)//100 - 1] \
                            = mean_valid_loss
                        if verbose:
                            print(epoch+1, 'epochs completed. Loss:',
                                  mean_valid_loss)

                    # Interrupt if threshold is reached
                    if mean_valid_loss < self._threshold:
                        break
            toc_train = time.perf_counter()
            if verbose:
                print('Finished training model!')
                print(f'Training time: {toc_train-tic_train:0.4f}s\n')
            toc = time.perf_counter()
            if verbose:
                print(f'Total runtime: {toc-tic:0.4f}s\n')

        def test_model(self, training_set: torch.utils.data.dataset.TensorDataset,
                       test_set: torch.utils.data.dataset.TensorDataset) -> dict:
            """Evaluates predictions of a model.

            Trains a model and compares the predicted and true results by
            evaluating precision, recall, and f-score for both classes,
            as well as accuracy and AUROC score.

            The model is trained for 50 epochs, continuing from its current
            weights. The predicted class is the index of the largest model
            output; the true class is read from the second column of the
            test labels.

            Parameters
            ----------
            training_set : torch.utils.data.dataset.TensorDataset
                Training dataset.
            test_set : torch.utils.data.dataset.TensorDataset
                Test dataset. It is unpacked into exactly two tensors
                (input data, output data).

            Returns
            -------
            dict
                Dictionary containing classification evaluation data.

            """
            # Train model
            self.epoch_training(training_set, num_epochs=50, verbose=False)
            self._model.eval()

            # Classify data (no gradients are needed for pure inference)
            x_test, y_test = test_set
            with torch.no_grad():
                model_score = self._model(x_test.float())
            model_output = torch.argmax(model_score, dim=1)

            # Evaluate classification
            y_true = y_test.detach().numpy()[:, 1]
            y_pred = model_output.detach().numpy()
            accuracy = accuracy_score(y_true, y_pred)
            precision, recall, f_score, support = precision_recall_fscore_support(
                y_true, y_pred, zero_division=0)
            auroc = roc_auc_score(y_true, y_pred)

            return {'Precision_Smooth': precision[0],
                    'Precision_Troubled': precision[1],
                    'Recall_Smooth': recall[0],
                    'Recall_Troubled': recall[1],
                    'F-Score_Smooth': f_score[0],
                    'F-Score_Troubled': f_score[1],
                    'Accuracy': accuracy,
                    'AUROC': auroc}

        def save_model(self, directory: str,
                       model_name: str = 'test_model') -> None:
            """Saves state and validation loss of a model.

            Writes '<model_name>.model.pt' (model state) and
            '<model_name>.loss.pt' (validation loss) into the subdirectory
            'trained models' of the given directory.

            Parameters
            ----------
            directory : str
               Path to directory in which model is saved.
            model_name : str, optional
                Name of model for saving. Default: 'test_model'.

            """
            # Set paths for files if not existing already
            model_dir = directory + '/trained models'
            os.makedirs(model_dir, exist_ok=True)

            # Save model and loss
            torch.save(self._model.state_dict(), model_dir + '/' +
                       model_name + '.model.pt')
            torch.save(self._validation_loss, model_dir + '/' +
                       model_name + '.loss.pt')
    
    
    def read_training_data(directory: str, normalized: bool = True) -> \
            torch.utils.data.dataset.TensorDataset:
        """Reads training data from directory.

        The input data is loaded from 'input_data.normalized.npy' or
        'input_data.raw.npy' (depending on the flag), the output data from
        'output_data.npy'.

        Parameters
        ----------
        directory : str
            Path to directory in which training data is saved.
        normalized : bool, optional
            Flag whether normalized data should be used. Default: True.

        Returns
        -------
        torch.utils.data.dataset.TensorDataset
            Training dataset.

        """
        # Select the input file matching the requested normalization
        if normalized:
            input_name = '/input_data.normalized.npy'
        else:
            input_name = '/input_data.raw.npy'

        # Load both arrays from file before mapping them to Torch tensors
        input_array = np.load(directory + input_name)
        output_array = np.load(directory + '/output_data.npy')

        input_tensor = torch.tensor(input_array)
        output_tensor = torch.tensor(output_array)
        return TensorDataset(input_tensor, output_tensor)
    
    
    def evaluate_models(models: dict, directory: str, num_iterations: int = 100,
                        compare_normalization: bool = False) -> None:
        """Evaluates the classification of a given set of models.

        Evaluates the classification with repeated 5-fold cross validation
        and saves the results in a JSON file. The file is written to the
        subdirectory 'model evaluation' of the given directory and named
        after the model names joined by underscores.

        Parameters
        ----------
        models : dict
            Dictionary of models to evaluate. Each value has to provide a
            method test_model(training_set, test_set) returning a
            dictionary of evaluation measures.
        directory : str
            Path to directory from which the training data is read and in
            which the evaluation results are saved.
        num_iterations : int, optional
            Number of iterations for evaluation. Default: 100.
        compare_normalization : bool, optional
            Flag whether both normalized and raw data should be evaluated.
            Default: False.

        """
        tic = time.perf_counter()

        # Read training data
        print('Read normalized training data.')
        datasets = {'normalized': read_training_data(directory)}
        if compare_normalization:
            print('Read raw, non-normalized training data.')
            datasets['raw'] = read_training_data(directory, False)

        # Labels under which the results of every model/dataset combination
        # are collected for each measure
        labels = [model + ' (' + dataset + ')'
                  for model in models for dataset in datasets]

        # Status is reported every tenth of the run (rounded down to a
        # multiple of 10), but never more often than every 10 iterations
        report_interval = max(10, 10*(num_iterations//100))

        # Train models for evaluation
        print('\nTraining models with 5-fold cross validation...')
        print('Number of iterations:', num_iterations)
        tic_train = time.perf_counter()
        classification_stats = {}
        for iteration in range(num_iterations):
            # Split data for cross validation; the same index split is used
            # for all datasets
            for train_index, test_index in KFold(
                    n_splits=5, shuffle=True).split(datasets['normalized']):
                for dataset in datasets:
                    training_set = TensorDataset(*datasets[dataset][train_index])
                    test_set = datasets[dataset][test_index]

                    # Save results for each model on split dataset
                    for model in models:
                        result = models[model].test_model(training_set, test_set)
                        for measure, value in result.items():
                            if measure not in classification_stats:
                                classification_stats[measure] = \
                                    {label: [] for label in labels}
                            classification_stats[measure][
                                model + ' (' + dataset + ')'].append(value)

            # Report status (parentheses are required: '%' binds tighter
            # than '+')
            if (iteration+1) % report_interval == 0:
                print(iteration+1, 'iterations completed.')
        toc_train = time.perf_counter()
        print('Finished training models with 5-fold cross validation!')
        print(f'Training time: {toc_train - tic_train:0.4f}s\n')

        # Set paths for plot files if not existing already
        plot_dir = directory + '/model evaluation'
        os.makedirs(plot_dir, exist_ok=True)

        # Save evaluation results in JSON format
        print('Saving evaluation results in JSON format.')
        with open(plot_dir + '/' + '_'.join(models.keys()) + '.json', 'w')\
                as json_file:
            json.dump(classification_stats, json_file)
        toc = time.perf_counter()
        print(f'Total runtime: {toc - tic:0.4f}s')