Skip to content
Snippets Groups Projects
Commit 9adbf4a6 authored by Laura Christine Kühle
Browse files

Added boxplot over all cross-validation results to classification evaluation.

parent bab035e0
Branches
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab) @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
TODO: Give option to compare multiple models TODO: Give option to compare multiple models
TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.) TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.) -> Done
TODO: Add log to pipeline TODO: Add log to pipeline
TODO: Remove object set-up TODO: Remove object set-up
TODO: Optimize Snakefile-vs-config relation TODO: Optimize Snakefile-vs-config relation
...@@ -23,7 +23,7 @@ from sklearn.model_selection import KFold ...@@ -23,7 +23,7 @@ from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, roc_auc_score, roc_curve from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, roc_auc_score, roc_curve
import ANN_Model import ANN_Model
from Plotting import plot_classification_accuracy from Plotting import plot_classification_accuracy, plot_boxplot
class ModelTrainer(object): class ModelTrainer(object):
...@@ -121,6 +121,7 @@ class ModelTrainer(object): ...@@ -121,6 +121,7 @@ class ModelTrainer(object):
# print(classification_stats) # print(classification_stats)
# print(np.array(classification_stats).mean(axis=0)) # print(np.array(classification_stats).mean(axis=0))
plot_boxplot([self._model_name], *np.array(classification_stats).transpose())
classification_stats = np.array(classification_stats).mean(axis=0) classification_stats = np.array(classification_stats).mean(axis=0)
plot_classification_accuracy([self._model_name], *classification_stats) plot_classification_accuracy([self._model_name], *classification_stats)
......
...@@ -259,7 +259,7 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, a ...@@ -259,7 +259,7 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, a
fscore = [fscore] fscore = [fscore]
auroc = [auroc] auroc = [auroc]
pos = np.arange(len(xlabels)) pos = np.arange(len(xlabels))
width = 0.3 width = 1/(3*len(xlabels))
fig = plt.figure('classification_accuracy') fig = plt.figure('classification_accuracy')
ax = fig.add_axes([0.15, 0.1, 0.75, 0.8]) ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
ax.bar(pos - 2*width, fscore, width, label='F-Score') ax.bar(pos - 2*width, fscore, width, label='F-Score')
...@@ -270,8 +270,45 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, a ...@@ -270,8 +270,45 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, a
ax.set_xticks(pos) ax.set_xticks(pos)
ax.set_xticklabels(xlabels) ax.set_xticklabels(xlabels)
ax.set_ylabel('Classification (%)') ax.set_ylabel('Classification (%)')
ax.set_ylim(bottom=0.6) ax.set_ylim(bottom=-0.02)
ax.set_ylim(top=1.02) ax.set_ylim(top=1.02)
ax.set_title('Non-Normalized Test Data') ax.set_title('Non-Normalized Test Data')
ax.legend(loc='upper right') ax.legend(loc='upper right')
# fig.tight_layout() # fig.tight_layout()
def plot_boxplot(xlabels, precision, recall, accuracy, fscore, auroc):
    """Draw one box per evaluation measure on the 'boxplot_accuracy' figure.

    Each measure is handed to ``Axes.boxplot`` as a single data set, so it
    yields exactly one box. The five boxes are placed side by side around
    every entry of ``xlabels`` and coloured per measure.

    Args:
        xlabels: Sequence of tick labels for the x-axis; its length sets
            the box width and the tick positions.
        precision: Sample values summarised by the 'Precision' box.
        recall: Sample values summarised by the 'Recall' box.
        accuracy: Sample values summarised by the 'Accuracy' box.
        fscore: Sample values summarised by the 'F-Score' box.
        auroc: Sample values summarised by the 'AUROC' box.

    Raises:
        ZeroDivisionError: If ``xlabels`` is empty.

    NOTE(review): presumably each measure holds one value per
    cross-validation fold -- confirm against the caller.
    NOTE(review): since every measure is a single data set while the
    positions have ``len(xlabels)`` entries, this looks like it only
    supports exactly one x-label -- confirm.
    """
    figure = plt.figure('boxplot_accuracy')
    centers = np.arange(len(xlabels))
    box_width = 1/(5*len(xlabels))
    axes = figure.add_axes([0.15, 0.1, 0.75, 0.8])

    # One row per measure, in drawing (left-to-right) order:
    # (legend label, face colour, samples, offset in units of box_width).
    series = (
        ('F-Score', 'red', fscore, -3),
        ('Precision', 'yellow', precision, -1.5),
        ('Recall', 'blue', recall, 0),
        ('Accuracy', 'tan', accuracy, 1.5),
        ('AUROC', 'green', auroc, 3),
    )

    legend_handles = []
    legend_labels = []
    for label, face_color, samples, offset in series:
        # Wrapping the samples in a list makes them a single data set.
        drawn = axes.boxplot([samples], positions=centers + offset*box_width,
                             widths=box_width, meanline=True, showmeans=True,
                             patch_artist=True)
        # patch_artist=True makes the boxes fillable patches.
        for box in drawn['boxes']:
            box.set(facecolor=face_color)
        legend_handles.append(drawn["boxes"][0])
        legend_labels.append(label)

    axes.set_xticks(centers)
    axes.set_xticklabels(xlabels)
    axes.set_ylim(bottom=-0.02)
    axes.set_ylim(top=1.02)
    axes.set_ylabel('Classification (%)')
    axes.set_title('Non-Normalized Test Data')
    axes.legend(legend_handles, legend_labels, loc='upper right')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment