diff --git a/ANN_Training.py b/ANN_Training.py index 8a7ba02b30e16e89ef3a9f91fc4712bca117e6b5..9e9d58d0281a6120e77184f6a32cb455238bfade 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -4,11 +4,13 @@ TODO: Give option to compare multiple models TODO: Use sklearn for classification -TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) +TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.) TODO: Add log to pipeline TODO: Remove object set-up TODO: Optimize Snakefile-vs-config relation +TODO: Add boxplot over CFV +TODO: Improve maximum selection runtime """ import numpy as np @@ -17,7 +19,8 @@ import os import torch from torch.utils.data import TensorDataset, DataLoader, random_split from sklearn.model_selection import KFold -# from sklearn.metrics import accuracy_score, precision_recall_fscore_support +# from sklearn.metrics import accuracy_score +from sklearn.metrics import accuracy_score, precision_recall_fscore_support import ANN_Model from Plotting import plot_classification_accuracy @@ -142,7 +145,9 @@ class ModelTrainer(object): x_test, y_test = test_set # print(self._model(x_test.float())) - model_output = torch.round(self._model(x_test.float())) + model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0] + for value in torch.max(self._model(x_test.float()), 1)[1]]) + # print(type(model_output), model_output) # acc = np.sum(model_output.numpy() == y_test.numpy()) # test_accuracy = (model_output == y_test).float().mean()