Commit bab035e0 authored by Laura Christine Kühle

Added f-score and AUROC to classification evaluation.

parent 0e2f343d
@@ -8,6 +8,9 @@ TODO: Add log to pipeline
 TODO: Remove object set-up
 TODO: Optimize Snakefile-vs-config relation
 TODO: Improve maximum selection runtime
+TODO: Discuss if we want training accuracy/ROC in addition to CFV
+TODO: Discuss whether to change output to binary
+TODO: Adapt TCD file to new classification

 """
 import numpy as np
@@ -17,7 +20,7 @@ import torch
 from torch.utils.data import TensorDataset, DataLoader, random_split
 from sklearn.model_selection import KFold
 # from sklearn.metrics import accuracy_score
-from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, roc_auc_score, roc_curve
 import ANN_Model
 from Plotting import plot_classification_accuracy
@@ -142,18 +145,28 @@ class ModelTrainer(object):
         x_test, y_test = test_set
         # print(self._model(x_test.float()))
+        model_score = self._model(x_test.float())
         model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
-                                     for value in torch.max(self._model(x_test.float()), 1)[1]])
+                                     for value in torch.max(model_score, 1)[1]])
-        y_true = y_test.detach().numpy()
-        y_pred = model_output.detach().numpy()
+        y_true = y_test.detach().numpy()[:, 0]
+        y_pred = model_output.detach().numpy()[:, 0]
+        # y_score = model_score.detach().numpy()[:, 0]
         accuracy = accuracy_score(y_true, y_pred)
         # print('sklearn', accuracy)
         precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
-        # print(precision, recall)
+        # print(precision, recall, f_score)
         # print()
-        return [precision[0], recall[0], accuracy]
+        # auroc = roc_auc_score(y_true, y_score)
+        # print('auroc raw', auroc)
+        auroc = roc_auc_score(y_true, y_pred)
+        print('auroc true', auroc)
+        # fpr, tpr, thresholds = roc_curve(y_true, y_score)
+        # roc = [tpr, fpr, thresholds]
+        # print(roc)
+        # plt.plot(fpr, tpr, label="AUC="+str(auroc))
+        return [precision[0], recall[0], accuracy, f_score[0], auroc]

     def save_model(self):
         # Saving Model
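Note on the AUROC call in this hunk: `roc_auc_score` is evaluated on the thresholded predictions `y_pred`, which collapses the ROC curve to a single operating point (numerically this equals the balanced accuracy of the thresholded classifier), while the commented-out lines sketch the ranking-based variant on the raw scores `y_score`. For reference, a minimal standalone sketch of the score-based call; the toy tensors and the softmax step are assumptions for illustration, not code from this repository:

import torch
from sklearn.metrics import roc_auc_score, roc_curve

# Toy stand-ins for y_test (one-hot targets) and model_score (raw two-column
# network output), shaped [n_samples, 2] as in the hunk above.
y_test = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
model_score = torch.tensor([[2.1, -0.3], [0.2, 1.5], [1.8, 0.4], [-0.5, 0.9]])

y_true = y_test.detach().numpy()[:, 0]
# Softmax turns the raw outputs into class probabilities; column 0 matches y_true.
y_score = torch.softmax(model_score, dim=1).detach().numpy()[:, 0]

# Ranking-based AUROC: uses the continuous scores directly, no 0/1 threshold.
auroc = roc_auc_score(y_true, y_score)

# roc_curve exposes the operating points (FPR/TPR per threshold) behind AUROC.
fpr, tpr, thresholds = roc_curve(y_true, y_score)
print(auroc, list(zip(fpr, tpr)))

Computed on hard 0/1 predictions, as above, the metric no longer measures ranking quality, so the commented-out `y_score` path may be worth revisiting.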
@@ -3,6 +3,7 @@
 @author: Laura C. Kühle

 TODO: Give option to select plotting color
+TODO: Improve classification plotting

 """
 import numpy as np
@@ -235,7 +236,7 @@ def calculate_exact_solution(mesh, cell_len, wave_speed, final_time, interval_le
     return grid, exact


-def plot_classification_accuracy(xlabels, precision, recall, accuracy):
+def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, auroc):
     """Plots classification accuracy.

     Plots the accuracy, precision, and recall in a bar plot.
@@ -255,13 +256,17 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy):
         precision = [precision]
         recall = [recall]
         accuracy = [accuracy]
+        fscore = [fscore]
+        auroc = [auroc]
     pos = np.arange(len(xlabels))
     width = 0.3
     fig = plt.figure('classification_accuracy')
     ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
+    ax.bar(pos - 2*width, fscore, width, label='F-Score')
     ax.bar(pos - width, precision, width, label='Precision')
     ax.bar(pos, recall, width, label='Recall')
     ax.bar(pos + width, accuracy, width, label='Accuracy')
+    ax.bar(pos + 2*width, auroc, width, label='AUROC')
     ax.set_xticks(pos)
     ax.set_xticklabels(xlabels)
     ax.set_ylabel('Classification (%)')
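A layout note on the five-bar groups above: with `width = 0.3`, one group spans from `pos - 2.5*width` to `pos + 2.5*width`, i.e. 1.5 axis units, so neighbouring groups overlap as soon as `xlabels` contains more than one entry (ticks sit one unit apart). A minimal standalone sketch with a narrower bar width; all labels and values below are hypothetical:

import numpy as np
import matplotlib.pyplot as plt

xlabels = ['fold 1', 'fold 2']   # hypothetical tick labels
metrics = {                      # hypothetical scores per metric
    'F-Score': [0.91, 0.88],
    'Precision': [0.93, 0.90],
    'Recall': [0.89, 0.86],
    'Accuracy': [0.92, 0.89],
    'AUROC': [0.95, 0.93],
}

pos = np.arange(len(xlabels))
width = 0.18                     # 5 * 0.18 = 0.9 < 1.0, so groups stay separate

fig, ax = plt.subplots()
for offset, (name, values) in enumerate(metrics.items()):
    # Centre the five bars on each tick: offsets -2..2 times the bar width.
    ax.bar(pos + (offset - 2) * width, values, width, label=name)
ax.set_xticks(pos)
ax.set_xticklabels(xlabels)
ax.set_ylabel('Classification (%)')
ax.legend()
plt.show()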