From eb7813bf57b09b67d969215a147073f14c4835b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 30 Nov 2021 15:51:45 +0100
Subject: [PATCH] Replaced classification with scikit-learn.

---
 ANN_Training.py | 66 +++++++++-----------------------------------------
 1 file changed, 12 insertions(+), 54 deletions(-)

diff --git a/ANN_Training.py b/ANN_Training.py
index 9e9d58d..ff943a5 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -3,7 +3,7 @@
 @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
 
 TODO: Give option to compare multiple models
-TODO: Use sklearn for classification
+TODO: Use sklearn for classification -> Done
 TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done
 TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.)
 TODO: Add log to pipeline
@@ -11,6 +11,7 @@ TODO: Remove object set-up
 TODO: Optimize Snakefile-vs-config relation
 TODO: Add boxplot over CFV
 TODO: Improve maximum selection runtime
+TODO: Fix tensor mapping warning
 """
 
 import numpy as np
@@ -20,7 +21,7 @@ import torch
 from torch.utils.data import TensorDataset, DataLoader, random_split
 from sklearn.model_selection import KFold
 # from sklearn.metrics import accuracy_score
-from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score
 
 import ANN_Model
 from Plotting import plot_classification_accuracy
@@ -147,59 +148,16 @@ class ModelTrainer(object):
         # print(self._model(x_test.float()))
         model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
                                      for value in torch.max(self._model(x_test.float()), 1)[1]])
-        # print(type(model_output), model_output)
-
-        # acc = np.sum(model_output.numpy() == y_test.numpy())
-        # test_accuracy = (model_output == y_test).float().mean()
-        # print(test_accuracy)
-        # print(model_output.nelement())
-        # accuracy1 = torch.sum(torch.eq(model_output, y_test)).item()  # /model_output.nelement()
-        # print(test_accuracy, accuracy1/model_output.nelement())
-        # print(accuracy1)
-
-        tp, fp, tn, fn = self._evaluate_classification(model_output, y_test)
-        precision, recall, accuracy = self._evaluate_stats(tp, fp, tn, fn)
+
+        y_true = y_test.detach().numpy()
+        y_pred = model_output.detach().numpy()
+        accuracy = accuracy_score(y_true, y_pred)
+        # print('sklearn', accuracy)
+        precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
         # print(precision, recall)
-        # print(accuracy)
-
-        return [precision, recall, accuracy]
-
-    @staticmethod
-    def _evaluate_classification(model_output, true_output):
-        # Positive being Discontinuous/Troubled Cells, Negative being Smooth/Good Cells
-        true_positive = 0
-        true_negative = 0
-        false_positive = 0
-        false_negative = 0
-        for i in range(true_output.size()[0]):
-            if model_output[i, 1] == model_output[i, 0]:
-                print(i, model_output[i])
-            if true_output[i, 0] == torch.tensor([1]):
-                if model_output[i, 0] == true_output[i, 0]:
-                    true_positive += 1
-                else:
-                    false_negative += 1
-            if true_output[i, 1] == torch.tensor([1]):
-                if model_output[i, 1] == true_output[i, 1]:
-                    true_negative += 1
-                else:
-                    false_positive += 1
-
-        return true_positive, true_negative, false_positive, false_negative
-
-    @staticmethod
-    def _evaluate_stats(true_positive, true_negative, false_positive, false_negative):
-        if true_positive+false_positive == 0:
-            precision = 0
-            recall = 0
-        else:
-            precision = true_positive / (true_positive+false_positive)
-            recall = true_positive / (true_positive+false_negative)
-
-        accuracy = (true_positive+true_negative) / (true_positive+true_negative
-                                                    + false_positive+false_negative)
-        # print(true_positive+true_negative+false_positive+false_negative)
-        return precision, recall, accuracy
 
     def save_model(self):
         # Saving Model
-- 
GitLab
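
For context, below is a minimal, self-contained sketch of the sklearn-based evaluation this patch introduces. It is not part of the patch; the logits and one-hot labels are dummy values invented for illustration. Column 0 encodes the positive "troubled cell" class, matching the convention of the removed _evaluate_classification helper.

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Dummy raw network outputs (logits) and one-hot ground truth for 4 cells;
# column 0 = troubled cell (positive), column 1 = good cell (negative).
raw_output = torch.tensor([[2.1, 0.3], [0.2, 1.7], [1.5, 0.9], [0.1, 2.2]])
y_test = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]])

# Same mapping as the patch: argmax 0 -> [1, 0], argmax 1 -> [0, 1].
model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
                             for value in torch.max(raw_output, 1)[1]])

y_true = y_test.detach().numpy()
y_pred = model_output.detach().numpy()

# On two complementary one-hot columns, sklearn's subset accuracy
# (exact row match) coincides with plain per-cell accuracy.
accuracy = accuracy_score(y_true, y_pred)

# With the default average=None, sklearn returns one score per column;
# f_score and support are unpacked but unused, mirroring the patch.
precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)

# Index 0 selects the troubled-cell class, as in the patched return value.
print([precision[0], recall[0], accuracy])  # -> [0.5, 1.0, 0.75]

The per-column indexing is why the patch returns precision[0] and recall[0]: with one-hot targets, precision_recall_fscore_support treats the task as multilabel and reports each class separately rather than a single averaged score.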