From eb7813bf57b09b67d969215a147073f14c4835b1 Mon Sep 17 00:00:00 2001
From: "Kühle, Laura Christine (lakue103)" <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 30 Nov 2021 15:51:45 +0100
Subject: [PATCH] Replaced hand-rolled classification metrics with scikit-learn.

---
 ANN_Training.py | 68 +++++++++++--------------------------------------
 1 file changed, 15 insertions(+), 53 deletions(-)
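
Notes (below the '---' marker, so not part of the commit message):

1. With one-hot labels, accuracy_score and precision_recall_fscore_support
   interpret the 2-D arrays as multilabel-indicator input: accuracy_score
   counts rows that match exactly, and precision_recall_fscore_support
   returns one value per column, so index 0 selects the first one-hot
   column (the Discontinuous/Troubled class). A minimal standalone sketch
   with made-up arrays, not project data:

       import numpy as np
       from sklearn.metrics import accuracy_score, precision_recall_fscore_support

       # Column 0 = Discontinuous/Troubled, column 1 = Smooth/Good.
       y_true = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
       y_pred = np.array([[1, 0], [0, 1], [0, 1], [0, 1]])

       # Subset accuracy: the fraction of rows matching exactly -> 3/4.
       print(accuracy_score(y_true, y_pred))  # 0.75

       # Per-column metrics; zero_division=0 silences the division-by-zero
       # warning and matches the removed helpers, which returned zeroed
       # metrics when no positives were predicted.
       precision, recall, f_score, support = precision_recall_fscore_support(
           y_true, y_pred, zero_division=0)
       print(precision[0], recall[0])  # 1.0 0.5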

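2. The new "Fix tensor mapping warning" TODO presumably refers to the
   Python-level loop that rebuilds one-hot rows from the torch.max indices
   (see the context lines at the top of the ModelTrainer hunk). A sketch of
   a vectorized alternative, an assumed rewrite that is not part of this
   commit, using torch.nn.functional.one_hot:

       import torch
       import torch.nn.functional as F

       # Assumed stand-in for the raw model output on two test samples.
       logits = torch.tensor([[0.2, 0.8], [0.9, 0.1]])
       # argmax over the class dimension, then expand back to one-hot rows
       # in a single vectorized call.
       model_output = F.one_hot(torch.argmax(logits, dim=1), num_classes=2).float()
       print(model_output)  # tensor([[0., 1.], [1., 0.]])
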
diff --git a/ANN_Training.py b/ANN_Training.py
index 9e9d58d..ff943a5 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -3,7 +3,7 @@
 @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
 
 TODO: Give option to compare multiple models
-TODO: Use sklearn for classification
+TODO: Use sklearn for classification -> Done
 TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done
 TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.)
 TODO: Add log to pipeline
@@ -11,6 +11,7 @@ TODO: Remove object set-up
 TODO: Optimize Snakefile-vs-config relation
 TODO: Add boxplot over CFV
 TODO: Improve maximum selection runtime
+TODO: Fix tensor mapping warning
 
 """
 import numpy as np
@@ -147,59 +148,20 @@ class ModelTrainer(object):
         # print(self._model(x_test.float()))
         model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
                                      for value in torch.max(self._model(x_test.float()), 1)[1]])
-        # print(type(model_output), model_output)
-
-        # acc = np.sum(model_output.numpy() == y_test.numpy())
-        # test_accuracy = (model_output == y_test).float().mean()
-        # print(test_accuracy)
-        # print(model_output.nelement())
-        # accuracy1 = torch.sum(torch.eq(model_output, y_test)).item()  # /model_output.nelement()
-        # print(test_accuracy, accuracy1/model_output.nelement())
-        # print(accuracy1)
-
-        tp, fp, tn, fn = self._evaluate_classification(model_output, y_test)
-        precision, recall, accuracy = self._evaluate_stats(tp, fp, tn, fn)
+
+        # Convert the one-hot tensors to NumPy arrays for the sklearn metrics.
+        y_true = y_test.detach().numpy()
+        y_pred = model_output.detach().numpy()
+        accuracy = accuracy_score(y_true, y_pred)
+        # zero_division=0 mirrors the removed helpers, which returned zeroed
+        # metrics when no troubled cells were predicted.
+        precision, recall, f_score, support = precision_recall_fscore_support(
+            y_true, y_pred, zero_division=0)
         # print(precision, recall)
-        # print(accuracy)
-
-        return [precision, recall, accuracy]
-
-    @staticmethod
-    def _evaluate_classification(model_output, true_output):
-        # Positive being Discontinuous/Troubled Cells, Negative being Smooth/Good Cells
-        true_positive = 0
-        true_negative = 0
-        false_positive = 0
-        false_negative = 0
-        for i in range(true_output.size()[0]):
-            if model_output[i, 1] == model_output[i, 0]:
-                print(i, model_output[i])
-            if true_output[i, 0] == torch.tensor([1]):
-                if model_output[i, 0] == true_output[i, 0]:
-                    true_positive += 1
-                else:
-                    false_negative += 1
-            if true_output[i, 1] == torch.tensor([1]):
-                if model_output[i, 1] == true_output[i, 1]:
-                    true_negative += 1
-                else:
-                    false_positive += 1
-
-        return true_positive, true_negative, false_positive, false_negative
-
-    @staticmethod
-    def _evaluate_stats(true_positive, true_negative, false_positive, false_negative):
-        if true_positive+false_positive == 0:
-            precision = 0
-            recall = 0
-        else:
-            precision = true_positive / (true_positive+false_positive)
-            recall = true_positive / (true_positive+false_negative)
-
-        accuracy = (true_positive+true_negative) / (true_positive+true_negative
-                                                    + false_positive+false_negative)
-        # print(true_positive+true_negative+false_positive+false_negative)
-        return precision, recall, accuracy
+
+        # Column 0 of the one-hot encoding is the positive
+        # (Discontinuous/Troubled) class.
+        return [precision[0], recall[0], accuracy]
 
     def save_model(self):
         # Saving Model
-- 
GitLab