Skip to content
Snippets Groups Projects
Commit 108a3943 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Added evaluation for all classes (if the measure allows).

parent 36e48e46
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,7 @@ TODO: Optimize Snakefile-vs-config relation ...@@ -8,7 +8,7 @@ TODO: Optimize Snakefile-vs-config relation
TODO: Improve maximum selection runtime -> Done TODO: Improve maximum selection runtime -> Done
TODO: Change model output to binary -> Do? (changes training when applied in ANN_Model) TODO: Change model output to binary -> Do? (changes training when applied in ANN_Model)
TODO: Adapt TCD file to new classification TODO: Adapt TCD file to new classification
TODO: Add evaluation for all classes (recall, precision, fscore) TODO: Add evaluation for all classes (recall, precision, fscore) -> Done
TODO: Add documentation TODO: Add documentation
""" """
...@@ -129,8 +129,10 @@ class ModelTrainer(object): ...@@ -129,8 +129,10 @@ class ModelTrainer(object):
# print(roc) # print(roc)
# plt.plot(fpr, tpr, label="AUC="+str(auroc)) # plt.plot(fpr, tpr, label="AUC="+str(auroc))
return {'Precision': precision[0], 'Recall': recall[0], 'Accuracy': accuracy, return {'Precision_Smooth': precision[0], 'Precision_Troubled': precision[1],
'F-Score': f_score[0], 'AUROC': auroc} 'Recall_Smooth': recall[0], 'Recall_Troubled': recall[1],
'F-Score_Smooth': f_score[0], 'F-Score_Troubled': f_score[1],
'Accuracy': accuracy, 'AUROC': auroc}
def save_model(self): def save_model(self):
# Saving Model # Saving Model
...@@ -158,8 +160,10 @@ def read_training_data(directory, normalized=True): ...@@ -158,8 +160,10 @@ def read_training_data(directory, normalized=True):
def evaluate_models(models, directory, num_iterations=100, colors=None, def evaluate_models(models, directory, num_iterations=100, colors=None,
compare_normalization=False): compare_normalization=False):
if colors is None: if colors is None:
colors = {'Accuracy': 'red', 'Precision': 'yellow', 'Recall': 'blue', colors = {'Accuracy': 'magenta', 'Precision_Smooth': 'red',
'F-Score': 'green', 'AUROC': 'purple'} 'Precision_Troubled': '#8B0000', 'Recall_Smooth': 'blue',
'Recall_Troubled': '#00008B', 'F-Score_Smooth': 'green',
'F-Score_Troubled': '#006400', 'AUROC': 'yellow'}
datasets = {'normalized': read_training_data(directory)} datasets = {'normalized': read_training_data(directory)}
if compare_normalization: if compare_normalization:
......
...@@ -26,11 +26,14 @@ functions: ...@@ -26,11 +26,14 @@ functions:
# Parameter for Model Training and Evaluation # Parameter for Model Training and Evaluation
compare_normalization: True compare_normalization: True
classification_colors: classification_colors:
Accuracy: 'magenta' Accuracy: '#FF00FF' # magenta
Precision: 'red' Precision_Smooth: '#FF0000' # red
Recall: 'tan' Precision_Troubled: '#8B0000' # dark red
F-Score: 'green' Recall_Smooth: '#0000FF' # blue
AUROC: 'yellow' Recall_Troubled: '#00008B' # dark blue
F-Score_Smooth: '#00FF00' # green
F-Score_Troubled: '#006400' # dark green
AUROC: '#FFFF00' # yellow
models: models:
Adam: Adam:
num_epochs: 1000 num_epochs: 1000
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment