diff --git a/ANN_Training.py b/ANN_Training.py index ff943a508a1f24b54beef157ddd8d5d152c0f0a3..209db3d9421f44f0464d431fe9bf240bb734534f 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -5,13 +5,12 @@ TODO: Give option to compare multiple models TODO: Use sklearn for classification -> Done TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done -TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.) +TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.) TODO: Add log to pipeline TODO: Remove object set-up TODO: Optimize Snakefile-vs-config relation -TODO: Add boxplot over CFV TODO: Improve maximum selection runtime -TODO: Fix tensor mapping warning +TODO: Fix tensor mapping warning -> Done """ import numpy as np @@ -115,7 +114,7 @@ class ModelTrainer(object): dataset = self._training_data for train_index, test_index in KFold(n_splits=5, shuffle=True).split(dataset): # print("TRAIN:", train_index, "TEST:", test_index) - training_set = TensorDataset(*map(torch.tensor, (dataset[train_index]))) + training_set = TensorDataset(*dataset[train_index]) test_set = dataset[test_index] classification_stats.append(self._test_fold(training_set, test_set))