From a34d6f4f23ecadece068b6a7709706b7b7e8ee0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?= <laura.kuehle@uni-duesseldorf.de> Date: Tue, 15 Feb 2022 17:20:24 +0100 Subject: [PATCH] Added saving of model evaluation data. --- ANN_Training.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ANN_Training.py b/ANN_Training.py index a2ce692..638f94d 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -6,11 +6,14 @@ Code-Style: E226, W503 Docstring-Style: D200, D400 TODO: Test new ANN set-up with Soraya -TODO: Remove object set-up (for more flexibility) -> Done (decided against it to keep easy test set-up) +TODO: Remove object set-up (for more flexibility) -> Done + (decided against it to keep easy test set-up) TODO: Add documentation TODO: Allow multiple approximations in one config -> Done TODO: Split workflow into multiple modules -> Done TODO: Remove unnecessary instance variables -> Done +TODO: Add option to change 'num_iterations' for model testing -> Done +TODO: Save model evaluation data -> Done TODO: Add README for ANN training TODO: Fix random seed TODO: Discuss whether to outsource scripts into separate directory @@ -24,6 +27,7 @@ import matplotlib from matplotlib import pyplot as plt import os import torch +import json from torch.utils.data import TensorDataset, DataLoader, random_split from sklearn.model_selection import KFold from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score @@ -196,6 +200,11 @@ def evaluate_models(models, directory, num_iterations=100, colors=None, print('Finished training models with 5-fold cross validation!') print(f'Training time: {toc_train - tic_train:0.4f}s\n') + with open(directory + '/' + '_'.join(models.keys()) + '.json', 'w') as json_file: + json_file.write(json.dumps(classification_stats)) + with open(directory + '/' + '_'.join(models.keys()) + '.json') as json_file: + classification_stats = json.load(json_file) + print('Plotting evaluation of trained models.') plot_boxplot(classification_stats, colors) classification_stats = {measure: {model + ' (' + dataset + ')': np.array( -- GitLab