Skip to content
Snippets Groups Projects
Commit a34d6f4f authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Added saving of model evaluation data.

parent dc854096
Branches
No related tags found
No related merge requests found
...@@ -6,11 +6,14 @@ Code-Style: E226, W503 ...@@ -6,11 +6,14 @@ Code-Style: E226, W503
Docstring-Style: D200, D400 Docstring-Style: D200, D400
TODO: Test new ANN set-up with Soraya TODO: Test new ANN set-up with Soraya
TODO: Remove object set-up (for more flexibility) -> Done (decided against it to keep easy test set-up) TODO: Remove object set-up (for more flexibility) -> Done
(decided against it to keep easy test set-up)
TODO: Add documentation TODO: Add documentation
TODO: Allow multiple approximations in one config -> Done TODO: Allow multiple approximations in one config -> Done
TODO: Split workflow into multiple modules -> Done TODO: Split workflow into multiple modules -> Done
TODO: Remove unnecessary instance variables -> Done TODO: Remove unnecessary instance variables -> Done
TODO: Add option to change 'num_iterations' for model testing -> Done
TODO: Save model evaluation data -> Done
TODO: Add README for ANN training TODO: Add README for ANN training
TODO: Fix random seed TODO: Fix random seed
TODO: Discuss whether to outsource scripts into separate directory TODO: Discuss whether to outsource scripts into separate directory
...@@ -24,6 +27,7 @@ import matplotlib ...@@ -24,6 +27,7 @@ import matplotlib
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import os import os
import torch import torch
import json
from torch.utils.data import TensorDataset, DataLoader, random_split from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.model_selection import KFold from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
...@@ -196,6 +200,11 @@ def evaluate_models(models, directory, num_iterations=100, colors=None, ...@@ -196,6 +200,11 @@ def evaluate_models(models, directory, num_iterations=100, colors=None,
print('Finished training models with 5-fold cross validation!') print('Finished training models with 5-fold cross validation!')
print(f'Training time: {toc_train - tic_train:0.4f}s\n') print(f'Training time: {toc_train - tic_train:0.4f}s\n')
with open(directory + '/' + '_'.join(models.keys()) + '.json', 'w') as json_file:
json_file.write(json.dumps(classification_stats))
with open(directory + '/' + '_'.join(models.keys()) + '.json') as json_file:
classification_stats = json.load(json_file)
print('Plotting evaluation of trained models.') print('Plotting evaluation of trained models.')
plot_boxplot(classification_stats, colors) plot_boxplot(classification_stats, colors)
classification_stats = {measure: {model + ' (' + dataset + ')': np.array( classification_stats = {measure: {model + ' (' + dataset + ')': np.array(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment