diff --git a/ANN_Training.py b/ANN_Training.py index a2ce6923f8acd9db0d5fd98f2cde367fe595939d..638f94de6a85350fa502327f477af47094b5a118 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -6,11 +6,14 @@ Code-Style: E226, W503 Docstring-Style: D200, D400 TODO: Test new ANN set-up with Soraya -TODO: Remove object set-up (for more flexibility) -> Done (decided against it to keep easy test set-up) +TODO: Remove object set-up (for more flexibility) -> Done + (decided against it to keep easy test set-up) TODO: Add documentation TODO: Allow multiple approximations in one config -> Done TODO: Split workflow into multiple modules -> Done TODO: Remove unnecessary instance variables -> Done +TODO: Add option to change 'num_iterations' for model testing -> Done +TODO: Save model evaluation data -> Done TODO: Add README for ANN training TODO: Fix random seed TODO: Discuss whether to outsource scripts into separate directory @@ -24,6 +27,7 @@ import matplotlib from matplotlib import pyplot as plt import os import torch +import json from torch.utils.data import TensorDataset, DataLoader, random_split from sklearn.model_selection import KFold from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score @@ -196,6 +200,11 @@ def evaluate_models(models, directory, num_iterations=100, colors=None, print('Finished training models with 5-fold cross validation!') print(f'Training time: {toc_train - tic_train:0.4f}s\n') + with open(directory + '/' + '_'.join(models.keys()) + '.json', 'w') as json_file: + json_file.write(json.dumps(classification_stats)) + with open(directory + '/' + '_'.join(models.keys()) + '.json') as json_file: + classification_stats = json.load(json_file) + print('Plotting evaluation of trained models.') plot_boxplot(classification_stats, colors) classification_stats = {measure: {model + ' (' + dataset + ')': np.array(