diff --git a/ANN_Training.py b/ANN_Training.py index 191efef5dec6b6c52abbf53206a0f8d899deb7e5..09ee01800ad334426fb58998259c47ffc76bf244 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -3,7 +3,7 @@ @author: Laura C. Kühle, Soraya Terrab (sorayaterrab) TODO: Give option to compare multiple models -TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.) +TODO: Add more evaluation measures (AUROC, ROC, F1, training accuracy, boxplot over CVF, etc.) -> Done TODO: Add log to pipeline TODO: Remove object set-up TODO: Optimize Snakefile-vs-config relation @@ -23,7 +23,7 @@ from sklearn.model_selection import KFold from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, roc_auc_score, roc_curve import ANN_Model -from Plotting import plot_classification_accuracy +from Plotting import plot_classification_accuracy, plot_boxplot class ModelTrainer(object): @@ -121,6 +121,7 @@ class ModelTrainer(object): # print(classification_stats) # print(np.array(classification_stats).mean(axis=0)) + plot_boxplot([self._model_name], *np.array(classification_stats).transpose()) classification_stats = np.array(classification_stats).mean(axis=0) plot_classification_accuracy([self._model_name], *classification_stats) diff --git a/Plotting.py b/Plotting.py index 54f8753304b7e6c644aae5c03fe1be93b77ee81e..813ecb89f25d407a2ef68ee5b42810f89b40a876 100644 --- a/Plotting.py +++ b/Plotting.py @@ -259,7 +259,7 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, a fscore = [fscore] auroc = [auroc] pos = np.arange(len(xlabels)) - width = 0.3 + width = 1/(3*len(xlabels)) fig = plt.figure('classification_accuracy') ax = fig.add_axes([0.15, 0.1, 0.75, 0.8]) ax.bar(pos - 2*width, fscore, width, label='F-Score') @@ -270,8 +270,45 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, a ax.set_xticks(pos) ax.set_xticklabels(xlabels) ax.set_ylabel('Classification (%)') - 
def plot_boxplot(xlabels, precision, recall, accuracy, fscore, auroc):
    """Plot boxplots of classification measures over cross-validation folds.

    Draws one group of five boxplots (F-Score, Precision, Recall, Accuracy,
    AUROC) per x-label into the figure named 'boxplot_accuracy'.

    Parameters
    ----------
    xlabels : list
        Labels (model names) for the x-axis; one group of boxes per label.
    precision, recall, accuracy, fscore, auroc : array-like
        Per-fold values of each classification measure.
        # NOTE(review): caller passes one model's per-fold stats — confirm
        # multi-model input shape before relying on len(xlabels) > 1.

    """
    # Order, colors, and x-offsets (in units of box width) for the five
    # measures; the legend labels below must stay aligned with this order.
    # boxplot expects a sequence of datasets, hence the extra list level.
    datasets = [[fscore], [precision], [recall], [accuracy], [auroc]]
    names = ['F-Score', 'Precision', 'Recall', 'Accuracy', 'AUROC']
    colors = ['red', 'yellow', 'blue', 'tan', 'green']
    offsets = [-3, -1.5, 0, 1.5, 3]

    fig = plt.figure('boxplot_accuracy')
    ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
    pos = np.arange(len(xlabels))
    # Scale the box width so all five boxes fit within one label slot.
    width = 1/(5*len(xlabels))

    boxplots = []
    for data, offset, color in zip(datasets, offsets, colors):
        bp = ax.boxplot(data, positions=pos + offset*width, widths=width,
                        meanline=True, showmeans=True, patch_artist=True)
        # patch_artist=True makes the boxes fillable patches.
        for box in bp['boxes']:
            box.set(facecolor=color)
        boxplots.append(bp)

    ax.set_xticks(pos)
    ax.set_xticklabels(xlabels)
    ax.set_ylim(bottom=-0.02)
    ax.set_ylim(top=1.02)
    ax.set_ylabel('Classification (%)')
    ax.set_title('Non-Normalized Test Data')
    # One representative box patch per measure carries the legend entry.
    ax.legend([bp['boxes'][0] for bp in boxplots], names, loc='upper right')