From def8763676096f736924a4749111682e101b1836 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 30 Nov 2021 22:02:34 +0100
Subject: [PATCH] Fixed tensor mapping warning.

---
 ANN_Training.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/ANN_Training.py b/ANN_Training.py
index ff943a5..209db3d 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -5,13 +5,12 @@
 TODO: Give option to compare multiple models
 TODO: Use sklearn for classification -> Done
 TODO: Fix difference between accuracies (stems from rounding; choose higher value instead) -> Done
-TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, etc.)
+TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.)
 TODO: Add log to pipeline
 TODO: Remove object set-up
 TODO: Optimize Snakefile-vs-config relation
-TODO: Add boxplot over CFV
 TODO: Improve maximum selection runtime
-TODO: Fix tensor mapping warning
+TODO: Fix tensor mapping warning -> Done
 
 """
 import numpy as np
@@ -115,7 +114,7 @@ class ModelTrainer(object):
             dataset = self._training_data
             for train_index, test_index in KFold(n_splits=5, shuffle=True).split(dataset):
                 # print("TRAIN:", train_index, "TEST:", test_index)
-                training_set = TensorDataset(*map(torch.tensor, (dataset[train_index])))
+                training_set = TensorDataset(*dataset[train_index])
                 test_set = dataset[test_index]
 
                 classification_stats.append(self._test_fold(training_set, test_set))
-- 
GitLab