Skip to content
Snippets Groups Projects
Select Git revision
2 results Searching

__init__.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    ANN_Training.py 8.76 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    TODO: Add log to pipeline
    TODO: Remove object set-up
    TODO: Optimize Snakefile-vs-config relation
    TODO: Improve maximum selection runtime -> Done
    TODO: Change model output to binary -> Do? (changes training when applied in ANN_Model)
    TODO: Adapt TCD file to new classification
    TODO: Add evaluation for all classes (recall, precision, fscore)
    TODO: Add documentation
    
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    import torch
    from torch.utils.data import TensorDataset, DataLoader, random_split
    from sklearn.model_selection import KFold
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
    
    import ANN_Model
    from Plotting import plot_classification_accuracy, plot_boxplot
    
    
    class ModelTrainer(object):
        """Train, evaluate and save a model taken from ``ANN_Model``.

        The trainer is configured through a dictionary. Recognised keys (with
        their defaults) are ``dir`` ('test_data'), ``model_name`` ('0'),
        ``batch_size`` (half the training-set size, capped at 500),
        ``num_epochs`` (1000), ``threshold`` (1e-5), ``model`` ('ThreeLayerReLu'),
        ``model_config`` ({}), ``loss_function`` ('BCELoss'), ``loss_config`` ({}),
        ``optimizer`` ('Adam'), ``optimizer_config`` ({}) and ``learning_rate``
        (1e-2).

        """

        def __init__(self, config):
            """Set up the trainer from the configuration dictionary ``config``."""
            self._reset(config)

        def _reset(self, config):
            """Read the configuration and build model, loss function and optimizer.

            Parameters
            ----------
            config : dict
                Configuration of the trainer (see class docstring). It is not
                modified.

            Raises
            ------
            ValueError
                If the model is not an attribute of ``ANN_Model``, the loss
                function is not an attribute of ``torch.nn.modules.loss`` or the
                optimizer is not an attribute of ``torch.optim``.

            """
            # Work on a shallow copy so the pops below do not empty the caller's dictionary
            config = dict(config)

            self._dir = config.pop('dir', 'test_data')
            self._model_name = config.pop('model_name', '0')
            self._training_data = read_training_data(self._dir)

            self._batch_size = config.pop('batch_size', min(len(self._training_data)//2, 500))
            self._num_epochs = config.pop('num_epochs', 1000)
            self._threshold = config.pop('threshold', 1e-5)
            self._model = config.pop('model', 'ThreeLayerReLu')
            self._model_config = config.pop('model_config', {})
            self._loss_function = config.pop('loss_function', 'BCELoss')
            self._loss_config = config.pop('loss_config', {})
            self._optimizer = config.pop('optimizer', 'Adam')
            # Copied because the learning rate is written into it below
            self._optimizer_config = dict(config.pop('optimizer_config', {}))

            # Set learning rate (takes precedence over an 'lr' entry in the optimizer config)
            self._learning_rate = config.pop('learning_rate', 1e-2)
            self._optimizer_config['lr'] = self._learning_rate

            if not hasattr(ANN_Model, self._model):
                raise ValueError('Invalid model: "%s"' % self._model)
            if not hasattr(torch.nn.modules.loss, self._loss_function):
                raise ValueError('Invalid loss function: "%s"' % self._loss_function)
            if not hasattr(torch.optim, self._optimizer):
                raise ValueError('Invalid optimizer: "%s"' % self._optimizer)

            # Replace the names by the instantiated objects
            self._model = getattr(ANN_Model, self._model)(self._model_config)
            self._loss_function = getattr(torch.nn.modules.loss, self._loss_function)(
                **self._loss_config)
            self._optimizer = getattr(torch.optim, self._optimizer)(
                self._model.parameters(), **self._optimizer_config)
            # Buffer for the validation loss that is recorded every 100 epochs
            self._validation_loss = torch.zeros(self._num_epochs//10)

        def epoch_training(self, dataset=None, num_epochs=None):
            """Train the model and monitor the loss on a validation split.

            The dataset is randomly split 80/20 into training and validation
            data. Every 100 epochs the mean validation loss per batch is printed
            and stored. Training stops early once this mean drops below the
            configured threshold.

            Parameters
            ----------
            dataset : torch.utils.data.Dataset, optional
                Data to train on. Default: the training data read at set-up.
            num_epochs : int, optional
                Maximum number of epochs. Default: the configured number.

            """
            # Split data into training and validation set
            if dataset is None:
                dataset = self._training_data
            if num_epochs is None:
                num_epochs = self._num_epochs
            num_samples = len(dataset)
            train_ds, valid_ds = random_split(dataset, [round(num_samples*0.8), round(num_samples*0.2)])

            # Load sets
            train_dl = DataLoader(train_ds, batch_size=self._batch_size, shuffle=True)
            valid_dl = DataLoader(valid_ds, batch_size=self._batch_size * 2)
            num_valid_batches = len(valid_dl)

            # Training with Validation
            for epoch in range(num_epochs):
                self._model.train()
                for x_batch, y_batch in train_dl:
                    pred = self._model(x_batch.float())
                    loss = self._loss_function(pred, y_batch.float()).mean()

                    # Run back propagation, update the weights, and zero gradients for next batch
                    loss.backward()
                    self._optimizer.step()
                    self._optimizer.zero_grad()

                self._model.eval()
                with torch.no_grad():
                    valid_loss = sum(
                        self._loss_function(self._model(x_batch_valid.float()), y_batch_valid.float())
                        for x_batch_valid, y_batch_valid in valid_dl)
                    mean_valid_loss = valid_loss / num_valid_batches

                    if (epoch+1) % 100 == 0:
                        self._record_validation_loss((epoch+1)//100 - 1, mean_valid_loss)
                        print(epoch+1, mean_valid_loss)

                    if mean_valid_loss < self._threshold:
                        break

        def _record_validation_loss(self, slot, value):
            """Store ``value`` at index ``slot`` of the validation-loss buffer.

            The buffer is sized from the configured number of epochs, but
            training may be run for a different number of epochs. It is therefore
            padded with zeros when ``slot`` lies beyond its end instead of raising
            an IndexError.

            """
            missing = slot + 1 - self._validation_loss.numel()
            if missing > 0:
                self._validation_loss = torch.cat(
                    (self._validation_loss, self._validation_loss.new_zeros(missing)))
            self._validation_loss[slot] = value

        def test_model(self, training_set, test_set):
            """Train for 100 epochs on ``training_set`` and score on ``test_set``.

            NOTE(review): the model is not re-initialised here, so repeated calls
            continue training the same weights.

            Parameters
            ----------
            training_set : torch.utils.data.Dataset
                Data passed on to ``epoch_training``.
            test_set : tuple
                Pair ``(x_test, y_test)`` of tensors. Column 1 of ``y_test`` is
                used as the true label.

            Returns
            -------
            dict
                Values for 'Precision', 'Recall', 'Accuracy', 'F-Score' and
                'AUROC'. Precision, recall and F-score are entry 0 of the arrays
                returned by ``precision_recall_fscore_support``.

            """
            self.epoch_training(training_set, num_epochs=100)
            self._model.eval()

            x_test, y_test = test_set
            # Pure inference: no gradient graph is needed
            with torch.no_grad():
                model_score = self._model(x_test.float())
            model_output = torch.argmax(model_score, dim=1)

            y_true = y_test.detach().numpy()[:, 1]
            y_pred = model_output.detach().numpy()
            accuracy = accuracy_score(y_true, y_pred)
            precision, recall, f_score, _ = precision_recall_fscore_support(y_true, y_pred)
            # AUROC is computed from the predicted classes (argmax), not from the raw model scores
            auroc = roc_auc_score(y_true, y_pred)
            print('auroc true', auroc)

            return {'Precision': precision[0], 'Recall': recall[0], 'Accuracy': accuracy,
                    'F-Score': f_score[0], 'AUROC': auroc}

        def save_model(self):
            """Save the model state and the validation loss.

            Both files are written to ``<dir>/trained models`` as
            ``model__<model_name>.pt`` and ``loss__<model_name>.pt``.

            """
            name = self._model_name

            # Create the output directory; exist_ok avoids the check-then-create race
            model_dir = self._dir + '/trained models'
            os.makedirs(model_dir, exist_ok=True)

            torch.save(self._model.state_dict(), model_dir + '/model__' + name + '.pt')
            torch.save(self._validation_loss, model_dir + '/loss__' + name + '.pt')
    
        # def _classify(self):
        #     pass
    
    
    def read_training_data(directory, normalized=True):
        """Load the stored training data as a Torch dataset.

        Parameters
        ----------
        directory : str
            Directory containing the ``.npy`` files.
        normalized : bool, optional
            If True, the input is read from 'normalized_input_data.npy',
            otherwise from 'input_data.npy'. Default: True.

        Returns
        -------
        torch.utils.data.TensorDataset
            Dataset pairing the input data with the data in 'output_data.npy'.

        """
        if normalized:
            input_name = '/normalized_input_data.npy'
        else:
            input_name = '/input_data.npy'

        # Read both arrays from disk before converting them to tensors
        input_array = np.load(directory + input_name)
        output_array = np.load(directory + '/output_data.npy')

        return TensorDataset(torch.tensor(input_array), torch.tensor(output_array))
    
    
    def evaluate_models(models, directory, num_iterations=100, colors=None,
                        compare_normalization=False):
        """Evaluate models with repeated 5-fold cross validation and plot the results.

        For every iteration a new shuffled 5-fold split is drawn. Each model is
        tested on every fold via its ``test_model`` method. The collected
        measures are shown as boxplots and, averaged over all runs, as a
        classification-accuracy plot. All labelled figures are saved as PDF in
        ``<directory>/model evaluation/<figure label>/``.

        Parameters
        ----------
        models : dict
            Maps a model name to an object providing
            ``test_model(training_set, test_set)``, which returns a dictionary
            with one entry per measure.
        directory : str
            Directory the training data is read from and the plots are saved to.
        num_iterations : int, optional
            Number of times the cross validation is repeated. Default: 100.
        colors : dict, optional
            Maps each evaluated measure to its plot color. Default: colors for
            'Accuracy', 'Precision', 'Recall', 'F-Score' and 'AUROC'.
        compare_normalization : bool, optional
            If True, the models are also evaluated on the raw (not normalized)
            data. Default: False.

        """
        if colors is None:
            colors = {'Accuracy': 'red', 'Precision': 'yellow', 'Recall': 'blue',
                      'F-Score': 'green', 'AUROC': 'purple'}

        datasets = {'normalized': read_training_data(directory)}
        if compare_normalization:
            datasets['raw'] = read_training_data(directory, False)

        # One list of results per measure and per model/dataset combination
        labels = [model + ' (' + dataset + ')' for model in models for dataset in datasets]
        classification_stats = {measure: {label: [] for label in labels} for measure in colors}

        for _ in range(num_iterations):
            # The split indices are drawn once on the normalized data and reused for every dataset
            for train_index, test_index in KFold(
                    n_splits=5, shuffle=True).split(datasets['normalized']):
                for dataset_name, dataset in datasets.items():
                    training_set = TensorDataset(*dataset[train_index])
                    test_set = dataset[test_index]

                    for model_name, trainer in models.items():
                        result = trainer.test_model(training_set, test_set)
                        label = model_name + ' (' + dataset_name + ')'
                        for measure in colors:
                            classification_stats[measure][label].append(result[measure])

        plot_boxplot(classification_stats, colors)

        # Reduce every list of results to its mean for the summary plot
        classification_stats = {
            measure: {label: np.array(values).mean() for label, values in stats.items()}
            for measure, stats in classification_stats.items()}
        plot_classification_accuracy(classification_stats, colors)

        # Create the plot directory; exist_ok avoids the check-then-create race
        plot_dir = directory + '/model evaluation'
        os.makedirs(plot_dir, exist_ok=True)

        # Save plots
        file_name = '_'.join(models.keys()) + '.pdf'
        for identifier in plt.get_figlabels():
            figure_dir = plot_dir + '/' + identifier
            os.makedirs(figure_dir, exist_ok=True)

            plt.figure(identifier)
            plt.savefig(figure_dir + '/' + file_name)
    
    
    # Loss Functions: BCELoss, BCEWithLogitsLoss,
    # CrossEntropyLoss (not working), MSELoss (with reduction='sum')
    # Optimizer: Adam, SGD
    # trainer = ModelTrainer({'num_epochs': 1000})
    # trainer.epoch_training()
    # trainer.test_model()
    # trainer.save_model()