From a34d6f4f23ecadece068b6a7709706b7b7e8ee0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 15 Feb 2022 17:20:24 +0100
Subject: [PATCH] Added saving of model evaluation data.

---
 ANN_Training.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/ANN_Training.py b/ANN_Training.py
index a2ce692..638f94d 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -6,11 +6,14 @@ Code-Style: E226, W503
 Docstring-Style: D200, D400
 
 TODO: Test new ANN set-up with Soraya
-TODO: Remove object set-up (for more flexibility) -> Done (decided against it to keep easy test set-up)
+TODO: Remove object set-up (for more flexibility) -> Done
+    (decided against it to keep easy test set-up)
 TODO: Add documentation
 TODO: Allow multiple approximations in one config -> Done
 TODO: Split workflow into multiple modules -> Done
 TODO: Remove unnecessary instance variables -> Done
+TODO: Add option to change 'num_iterations' for model testing -> Done
+TODO: Save model evaluation data -> Done
 TODO: Add README for ANN training
 TODO: Fix random seed
 TODO: Discuss whether to outsource scripts into separate directory
@@ -24,6 +27,7 @@ import matplotlib
 from matplotlib import pyplot as plt
 import os
 import torch
+import json
 from torch.utils.data import TensorDataset, DataLoader, random_split
 from sklearn.model_selection import KFold
 from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
@@ -196,6 +200,11 @@ def evaluate_models(models, directory, num_iterations=100, colors=None,
     print('Finished training models with 5-fold cross validation!')
     print(f'Training time: {toc_train - tic_train:0.4f}s\n')
 
+    with open(directory + '/' + '_'.join(models.keys()) + '.json', 'w') as json_file:
+        json_file.write(json.dumps(classification_stats))
+    with open(directory + '/' + '_'.join(models.keys()) + '.json') as json_file:
+        classification_stats = json.load(json_file)
+
     print('Plotting evaluation of trained models.')
     plot_boxplot(classification_stats, colors)
     classification_stats = {measure: {model + ' (' + dataset + ')': np.array(
-- 
GitLab