From dae614165a6868e8932e23708619021d0d38ce6b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 7 Dec 2021 23:01:16 +0100
Subject: [PATCH] Removed unnecessary imports.

---
 ANN_Training.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/ANN_Training.py b/ANN_Training.py
index ecd95b8..892503a 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -3,16 +3,25 @@
 @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)

 TODO: Give option to compare multiple models -> Done
-TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.) -> Done
+TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.)
+    -> Done
 TODO: Add log to pipeline
 TODO: Remove object set-up
 TODO: Optimize Snakefile-vs-config relation
 TODO: Improve maximum selection runtime
-TODO: Discuss if we want training accuracy/ROC in addition to CFV
-TODO: Discuss whether to change output to binary
+TODO: Discuss if we want training accuracy/ROC in addition to CFV -> Done (No)
+TODO: Discuss whether to change output to binary -> Done (Yes)
+TODO: Change output to binary
 TODO: Adapt TCD file to new classification
 TODO: Improve classification stat handling -> Done
 TODO: Discuss automatic comparison between (non-)normalized data
+    -> Done (Flag for comparison)
+TODO: Add flag for evaluation of non-normalized data as well -> Next!
+TODO: Add evaluation for all classes (recall, precision, fscore)
+TODO: Add documentation
+TODO: Separate model training in Snakefile by using wildcards -> Done
+TODO: Correct import statements -> Done (Installed new version)
+TODO: Remove unnecessary imports -> Done

 """
 import numpy as np
@@ -21,8 +30,7 @@
 import os
 import torch
 from torch.utils.data import TensorDataset, DataLoader, random_split
 from sklearn.model_selection import KFold
-# from sklearn.metrics import accuracy_score
-from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, roc_auc_score, roc_curve
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
 import ANN_Model
 from Plotting import plot_classification_accuracy, plot_boxplot

@@ -156,7 +164,7 @@ def read_training_data(directory):
     return TensorDataset(*map(torch.tensor, (np.load(input_file), np.load(output_file))))


-def evaluate_models(models, directory, num_iterations=100, colors = None):
+def evaluate_models(models, directory, num_iterations=100, colors=None):
     if colors is None:
         colors = {'Accuracy': 'red', 'Precision': 'yellow', 'Recall': 'blue',
                   'F-Score': 'green', 'AUROC': 'purple'}
--
GitLab
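
Note: the patch only touches the imports and the evaluate_models() signature, so the metric computation itself is not shown. The following is a minimal sketch of how the three retained sklearn.metrics calls could produce the five quantities named in the colors dict, assuming binary classification with 1-D integer labels; compute_metrics, model, x_test and y_test are hypothetical placeholders, not names from the repository.

# Illustrative sketch only -- not the repository's actual evaluate_models() body.
# model, x_test, y_test are hypothetical placeholders; a trained PyTorch model
# and a held-out test split with 1-D integer labels are assumed.
import torch
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             roc_auc_score)


def compute_metrics(model, x_test, y_test):
    model.eval()
    with torch.no_grad():
        # The model is assumed to output one score per class (shape [N, 2]).
        scores = model(x_test.float())
    y_pred = torch.argmax(scores, dim=1).numpy()
    y_true = y_test.numpy()

    # Macro-averaged precision/recall/F-score over both classes.
    precision, recall, f_score, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0)

    return {'Accuracy': accuracy_score(y_true, y_pred),
            'Precision': precision,
            'Recall': recall,
            'F-Score': f_score,
            # AUROC uses the raw score of the positive class, not the hard label.
            'AUROC': roc_auc_score(y_true, scores[:, 1].numpy())}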