Skip to content
Snippets Groups Projects
Select Git revision
  • 28971c8a6ee92485d52f6a3277b4de10ced97919
  • master default protected
2 results

workerInterface.js

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    ANN_Training.py 8.76 KiB
    # -*- coding: utf-8 -*-
    """
    @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
    
    TODO: Add ANN classification from Soraya
    TODO: Give option to compare multiple models
    TODO: Use sklearn for classification
    TODO: Fix difference between accuracies (stems from rounding; choose higher value instead)
    TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.)
    TODO: Decide on k-fold cross-validation (Use? Which model do we keep?)
    TODO: Rework model testing
    TODO: Clean up directories/naming
    TODO: Add log to pipeline
    
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    import torch
    from torch.utils.data import TensorDataset, DataLoader, random_split
    # from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    
    import ANN_Model
    from Plotting import plot_classification_accuracy
    
    
    class ModelTrainer(object):
        """Train, validate, and test an ANN classifier for troubled-cell detection.

        The configuration dict passed to the constructor is consumed via
        ``dict.pop`` with defaults; keys not listed below are left untouched.
        Recognized keys: data_dir, model_dir, plot_dir, training_data,
        batch_size, num_epochs, threshold, model, model_config, loss_function,
        loss_config, optimizer, optimizer_config, learning_rate.
        """

        def __init__(self, config):
            # All state lives in _reset() so a trainer can be re-initialized
            # with a fresh config without constructing a new instance.
            self._reset(config)

        def _reset(self, config):
            """(Re-)initialize directories, data, model, loss, and optimizer.

            Raises:
                ValueError: If the configured model, loss function, or optimizer
                    name is not found in ANN_Model, torch.nn.modules.loss, or
                    torch.optim, respectively.
            """
            data_dir = config.pop('data_dir', 'test_data')
            self._model_dir = config.pop('model_dir', 'test_data')
            self._plot_dir = config.pop('plot_dir', 'new_fig')
            self._data_file = config.pop(
                'training_data', 'smooth_0.05k__troubled_0.05k__normalized.npy')
            self._read_training_data(data_dir)

            self._batch_size = config.pop(
                'batch_size', min(len(self._training_data)//2, 500))
            self._num_epochs = config.pop('num_epochs', 1000)
            self._threshold = config.pop('threshold', 1e-5)
            # Component names stay in locals until validated, so the error
            # message below can report the offending string and the attributes
            # are only ever bound to the resolved instances.
            model_name = config.pop('model', 'ThreeLayerReLu')
            self._model_config = config.pop('model_config', {})
            loss_name = config.pop('loss_function', 'BCELoss')
            self._loss_config = config.pop('loss_config', {})
            optimizer_name = config.pop('optimizer', 'Adam')
            self._optimizer_config = config.pop('optimizer_config', {})

            # Set learning rate
            self._learning_rate = config.pop('learning_rate', 1e-2)
            self._optimizer_config['lr'] = self._learning_rate

            if not hasattr(ANN_Model, model_name):
                raise ValueError('Invalid model: "%s"' % model_name)
            if not hasattr(torch.nn.modules.loss, loss_name):
                raise ValueError('Invalid loss function: "%s"' % loss_name)
            if not hasattr(torch.optim, optimizer_name):
                raise ValueError('Invalid optimizer: "%s"' % optimizer_name)

            self._model = getattr(ANN_Model, model_name)(self._model_config)
            self._loss_function = getattr(torch.nn.modules.loss, loss_name)(
                **self._loss_config)
            self._optimizer = getattr(torch.optim, optimizer_name)(
                self._model.parameters(), **self._optimizer_config)
            # One slot per 100 epochs: validation loss is only recorded at
            # those checkpoints in epoch_training().
            self._validation_loss = torch.zeros(self._num_epochs//100)

        def _read_training_data(self, directory):
            """Load training input/output arrays and wrap them in a TensorDataset.

            NOTE(review): file names are currently hard-coded; self._data_file
            is configured but unused here (see commented suffixes) — confirm
            which naming scheme is intended.
            """
            input_file = directory + '/training_input.npy'  # + self._data_file
            output_file = directory + '/training_output.npy'  # + self._data_file
            self._training_data = TensorDataset(*map(torch.tensor, (np.load(input_file),
                                                                    np.load(output_file))))

        def epoch_training(self):
            """Train the model with an 80/20 train/validation split.

            Stops early once the epoch's mean validation loss falls below the
            configured threshold; otherwise runs for num_epochs epochs. Every
            100th epoch the mean validation loss is recorded and printed.
            """
            # Split data into training and validation set. The validation size
            # is the remainder so the two parts are guaranteed to sum to the
            # dataset length (two independent round() calls may not).
            dataset = self._training_data
            num_samples = len(dataset)
            num_train = round(num_samples * 0.8)
            train_ds, valid_ds = random_split(
                dataset, [num_train, num_samples - num_train])

            # Load sets (validation uses a larger batch since no gradients are kept)
            train_dl = DataLoader(train_ds, batch_size=self._batch_size, shuffle=True)
            valid_dl = DataLoader(valid_ds, batch_size=self._batch_size * 2)

            # Training with Validation
            for epoch in range(self._num_epochs):
                self._model.train()
                for x_batch, y_batch in train_dl:
                    pred = self._model(x_batch.float())
                    loss = self._loss_function(pred, y_batch.float()).mean()

                    # Run back propagation, update the weights, and zero gradients for next epoch
                    loss.backward()
                    self._optimizer.step()
                    self._optimizer.zero_grad()

                self._model.eval()
                with torch.no_grad():
                    valid_loss = sum(
                        self._loss_function(self._model(x_batch_valid.float()),
                                            y_batch_valid.float())
                        for x_batch_valid, y_batch_valid in valid_dl)

                    if (epoch+1) % 100 == 0:
                        self._validation_loss[(epoch+1)//100 - 1] = valid_loss / len(valid_dl)
                        print(epoch+1, valid_loss / len(valid_dl))

                    # Early stopping on the mean validation loss
                    if valid_loss / len(valid_dl) < self._threshold:
                        break

        def test_model(self):
            """Train the model, evaluate it on test data, and save accuracy plots."""
            self.epoch_training()
            self._model.eval()

            # FIXME: TensorDataset does not support string indexing; this line
            # predates the current data layout and will raise TypeError
            # (see module TODO "Rework model testing").
            x_test, y_test = self._training_data['test']
            model_output = torch.round(self._model(x_test.float()))

            # Unpacking order matches _evaluate_classification's return order
            # (tp, tn, fp, fn) — keep the names aligned with the values.
            tp, tn, fp, fn = self._evaluate_classification(model_output, y_test)
            precision, recall, accuracy = self._evaluate_stats(tp, tn, fp, fn)
            plot_classification_accuracy(precision, recall, accuracy, ['real'])

            # Set file name (bug fix: was self._validation_file, which is
            # never assigned anywhere in this class)
            test_name = self._data_file.split('.npy')[0]
            name = self._model.get_name() + '__' + self._optimizer.__class__.__name__ + '_' \
                + str(self._learning_rate) + '__' + self._loss_function.__class__.__name__ + '__' \
                + test_name

            # Set paths for plot files if not existing already
            if not os.path.exists(self._plot_dir):
                os.makedirs(self._plot_dir)

            # Save plots
            for identifier in plt.get_figlabels():
                # Set path for figure directory if not existing already
                if not os.path.exists(self._plot_dir + '/' + identifier):
                    os.makedirs(self._plot_dir + '/' + identifier)

                plt.figure(identifier)
                plt.savefig(self._plot_dir + '/' + identifier + '/' + name + '.pdf')

        @staticmethod
        def _evaluate_classification(model_output, true_output):
            """Count confusion-matrix entries for one-hot classification output.

            Positive = Discontinuous/Troubled Cells (column 0),
            Negative = Smooth/Good Cells (column 1).

            Returns:
                Tuple (true_positive, true_negative, false_positive,
                false_negative) — note the (tp, tn, fp, fn) order.
            """
            true_positive = 0
            true_negative = 0
            false_positive = 0
            false_negative = 0
            for i in range(true_output.size()[0]):
                # Debug aid: flag ambiguous predictions (both columns equal)
                if model_output[i, 1] == model_output[i, 0]:
                    print(i, model_output[i])
                if true_output[i, 0] == torch.tensor([1]):
                    if model_output[i, 0] == true_output[i, 0]:
                        true_positive += 1
                    else:
                        false_negative += 1
                if true_output[i, 1] == torch.tensor([1]):
                    if model_output[i, 1] == true_output[i, 1]:
                        true_negative += 1
                    else:
                        false_positive += 1

            return true_positive, true_negative, false_positive, false_negative

        @staticmethod
        def _evaluate_stats(true_positive, true_negative, false_positive, false_negative):
            """Compute (precision, recall, accuracy) from confusion counts.

            Any ratio with a zero denominator is reported as 0 instead of
            raising ZeroDivisionError (the original crashed e.g. for
            tp=0, fp>0, fn=0).
            """
            if true_positive+false_positive == 0:
                precision = 0
            else:
                precision = true_positive / (true_positive+false_positive)
            if true_positive+false_negative == 0:
                recall = 0
            else:
                recall = true_positive / (true_positive+false_negative)

            total = true_positive+true_negative+false_positive+false_negative
            accuracy = (true_positive+true_negative) / total if total != 0 else 0
            return precision, recall, accuracy

        def save_model(self):
            """Persist model weights and the recorded validation losses.

            Creates the model directory on demand; files are currently saved
            under fixed names (see commented-out per-configuration naming).
            """
            # path = self._model.get_name() + '__' + self._optimizer.__class__.__name__ + '_' \
            #     + str(self._learning_rate) + '__' + self._loss_function.__class__.__name__ + '.pt'

            # Set path for model files if not existing already
            os.makedirs(self._model_dir, exist_ok=True)

            torch.save(self._model.state_dict(), self._model_dir + '/model.pt')  # __' + path)
            torch.save(self._validation_loss, self._model_dir + '/loss.pt')  # __' + path)

        def _classify(self):
            # Placeholder (see module TODO: add ANN classification)
            pass
    
    
    # Loss Functions: BCELoss, BCEWithLogitsLoss,
    # CrossEntropyLoss (not working), MSELoss (with reduction='sum')
    # Optimizer: Adam, SGD
    # trainer = ModelTrainer({'num_epochs': 1000})
    # trainer.epoch_training()
    # trainer.test_model()
    # trainer.save_model()