Skip to content
Snippets Groups Projects
Commit 29e37136 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Replaced rounding with higher-value selection for classification.

parent 7db8d3d9
No related branches found
No related tags found
No related merge requests found
......@@ -4,11 +4,13 @@
TODO: Give option to compare multiple models
TODO: Use sklearn for classification
TODO: Fix difference between accuracies (stems from rounding; choose higher value instead)
TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done
TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.)
TODO: Add log to pipeline
TODO: Remove object set-up
TODO: Optimize Snakefile-vs-config relation
TODO: Add boxplot over CFV
TODO: Improve maximum selection runtime
"""
import numpy as np
......@@ -17,7 +19,8 @@ import os
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import KFold
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support
# from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import ANN_Model
from Plotting import plot_classification_accuracy
......@@ -142,7 +145,9 @@ class ModelTrainer(object):
x_test, y_test = test_set
# print(self._model(x_test.float()))
model_output = torch.round(self._model(x_test.float()))
model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
for value in torch.max(self._model(x_test.float()), 1)[1]])
# print(type(model_output), model_output)
# acc = np.sum(model_output.numpy() == y_test.numpy())
# test_accuracy = (model_output == y_test).float().mean()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment