From def8763676096f736924a4749111682e101b1836 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Tue, 30 Nov 2021 22:02:34 +0100 Subject: [PATCH] Fixed tensor mapping warning. --- ANN_Training.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ANN_Training.py b/ANN_Training.py index ff943a5..209db3d 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -5,13 +5,12 @@ TODO: Give option to compare multiple models TODO: Use sklearn for classification -> Done TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done -TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.) +TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.) TODO: Add log to pipeline TODO: Remove object set-up TODO: Optimize Snakefile-vs-config relation -TODO: Add boxplot over CFV TODO: Improve maximum selection runtime -TODO: Fix tensor mapping warning +TODO: Fix tensor mapping warning -> Done """ import numpy as np @@ -115,7 +114,7 @@ class ModelTrainer(object): dataset = self._training_data for train_index, test_index in KFold(n_splits=5, shuffle=True).split(dataset): # print("TRAIN:", train_index, "TEST:", test_index) - training_set = TensorDataset(*map(torch.tensor, (dataset[train_index]))) + training_set = TensorDataset(*dataset[train_index]) test_set = dataset[test_index] classification_stats.append(self._test_fold(training_set, test_set)) -- GitLab