Commit eb7813bf authored by Laura Christine Kühle

Replaced classification with scikit-learn.

parent 29e37136
@@ -3,7 +3,7 @@
 @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

 TODO: Give option to compare multiple models
-TODO: Use sklearn for classification
+TODO: Use sklearn for classification -> Done
 TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done
 TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.)
 TODO: Add log to pipeline
@@ -11,6 +11,7 @@ TODO: Remove object set-up
 TODO: Optimize Snakefile-vs-config relation
 TODO: Add boxplot over CFV
 TODO: Improve maximum selection runtime
+TODO: Fix tensor mapping warning
 """

 import numpy as np
@@ -20,7 +21,7 @@ import torch
 from torch.utils.data import TensorDataset, DataLoader, random_split
 from sklearn.model_selection import KFold
 # from sklearn.metrics import accuracy_score
-from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score

 import ANN_Model
 from Plotting import plot_classification_accuracy
@@ -147,59 +148,16 @@ class ModelTrainer(object):
         # print(self._model(x_test.float()))
         model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
                                      for value in torch.max(self._model(x_test.float()), 1)[1]])
-        # print(type(model_output), model_output)
-        # acc = np.sum(model_output.numpy() == y_test.numpy())
-        # test_accuracy = (model_output == y_test).float().mean()
-        # print(test_accuracy)
-        # print(model_output.nelement())
-        # accuracy1 = torch.sum(torch.eq(model_output, y_test)).item()  # /model_output.nelement()
-        # print(test_accuracy, accuracy1/model_output.nelement())
-        # print(accuracy1)
-        tp, fp, tn, fn = self._evaluate_classification(model_output, y_test)
-        precision, recall, accuracy = self._evaluate_stats(tp, fp, tn, fn)
+        y_true = y_test.detach().numpy()
+        y_pred = model_output.detach().numpy()
+        accuracy = accuracy_score(y_true, y_pred)
+        # print('sklearn', accuracy)
+        precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
         # print(precision, recall)
-        # print(accuracy)
-        return [precision, recall, accuracy]
-
-    @staticmethod
-    def _evaluate_classification(model_output, true_output):
-        # Positive being Discontinuous/Troubled Cells, Negative being Smooth/Good Cells
-        true_positive = 0
-        true_negative = 0
-        false_positive = 0
-        false_negative = 0
-        for i in range(true_output.size()[0]):
-            if model_output[i, 1] == model_output[i, 0]:
-                print(i, model_output[i])
-            if true_output[i, 0] == torch.tensor([1]):
-                if model_output[i, 0] == true_output[i, 0]:
-                    true_positive += 1
-                else:
-                    false_negative += 1
-            if true_output[i, 1] == torch.tensor([1]):
-                if model_output[i, 1] == true_output[i, 1]:
-                    true_negative += 1
-                else:
-                    false_positive += 1
-        return true_positive, true_negative, false_positive, false_negative
-
-    @staticmethod
-    def _evaluate_stats(true_positive, true_negative, false_positive, false_negative):
-        if true_positive+false_positive == 0:
-            precision = 0
-            recall = 0
-        else:
-            precision = true_positive / (true_positive+false_positive)
-            recall = true_positive / (true_positive+false_negative)
-        accuracy = (true_positive+true_negative) / (true_positive+true_negative
-                                                    + false_positive+false_negative)
-        # print(true_positive+true_negative+false_positive+false_negative)
-        return precision, recall, accuracy
+        # print()
+        return [precision[0], recall[0], accuracy]

     def save_model(self):
         # Saving Model
...
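
A minimal, self-contained sketch of the new scikit-learn evaluation path; raw_output and y_test here are hypothetical stand-ins for the network output and the one-hot test labels used in ModelTrainer:

    import torch
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    # Hypothetical stand-ins: raw network outputs and one-hot ground truth.
    raw_output = torch.tensor([[2.0, -1.0], [0.5, 1.5], [3.0, 0.0]])
    y_test = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

    # Map each argmax index back to a one-hot row, as in the diff above.
    model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
                                 for value in torch.max(raw_output, 1)[1]])

    y_true = y_test.detach().numpy()
    y_pred = model_output.detach().numpy()

    # Exact row match over one-hot rows, i.e. plain classification accuracy.
    accuracy = accuracy_score(y_true, y_pred)

    # One-hot labels are treated as multilabel-indicator data, so the
    # scores come back as one value per column; index 0 is the first class.
    precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
    print(accuracy, precision[0], recall[0])

Because the labels are one-hot encoded, accuracy_score computes subset accuracy (exact row match), which for one-hot rows equals ordinary classification accuracy, and precision_recall_fscore_support returns per-column arrays; this is why the new code indexes precision[0] and recall[0], matching the removed helper's convention that the first column marks troubled cells.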
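The new "Fix tensor mapping warning" TODO presumably refers to the list-comprehension mapping kept in this commit, which iterates over tensor elements in Python to rebuild a tensor. One possible fix, an assumption rather than part of this commit, is to build the one-hot rows directly from the argmax indices with torch.nn.functional.one_hot:

    import torch
    import torch.nn.functional as F

    # Hypothetical stand-in for the raw network output.
    raw_output = torch.tensor([[2.0, -1.0], [0.5, 1.5], [3.0, 0.0]])

    # Vectorized equivalent of the list-comprehension mapping:
    # argmax per row, then expand the indices to one-hot float rows.
    pred_idx = torch.max(raw_output, 1)[1]
    model_output = F.one_hot(pred_idx, num_classes=2).float()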