diff --git a/ANN_Training.py b/ANN_Training.py index 0f3afb7c3c0cb81e874a3c17d3dc923f4690e2dc..191efef5dec6b6c52abbf53206a0f8d899deb7e5 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -8,6 +8,9 @@ TODO: Add log to pipeline TODO: Remove object set-up TODO: Optimize Snakefile-vs-config relation TODO: Improve maximum selection runtime +TODO: Discuss if we want training accuracy/ROC in addition to CFV +TODO: Discuss whether to change output to binary +TODO: Adapt TCD file to new classification """ import numpy as np @@ -17,7 +20,7 @@ import torch from torch.utils.data import TensorDataset, DataLoader, random_split from sklearn.model_selection import KFold # from sklearn.metrics import accuracy_score -from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score +from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, roc_auc_score, roc_curve import ANN_Model from Plotting import plot_classification_accuracy @@ -142,18 +145,28 @@ class ModelTrainer(object): x_test, y_test = test_set # print(self._model(x_test.float())) + model_score = self._model(x_test.float()) model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0] - for value in torch.max(self._model(x_test.float()), 1)[1]]) + for value in torch.max(model_score, 1)[1]]) - y_true = y_test.detach().numpy() - y_pred = model_output.detach().numpy() + y_true = y_test.detach().numpy()[:, 0] + y_pred = model_output.detach().numpy()[:, 0] + # y_score = model_score.detach().numpy()[:, 0] accuracy = accuracy_score(y_true, y_pred) # print('sklearn', accuracy) precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred) - # print(precision, recall) + # print(precision, recall, f_score) # print() - - return [precision[0], recall[0], accuracy] + # auroc = roc_auc_score(y_true, y_score) + # print('auroc raw', auroc) + auroc = roc_auc_score(y_true, y_pred) + print('auroc true', auroc) + # fpr, tpr, thresholds = roc_curve(y_true, y_score) + # roc = [tpr, fpr, thresholds] + # print(roc) + # plt.plot(fpr, tpr, label="AUC="+str(auroc)) + + return [precision[0], recall[0], accuracy, f_score[0], auroc] def save_model(self): # Saving Model diff --git a/Plotting.py b/Plotting.py index 77c3bcdb028a7b68effd4d08c1956b70690a65ee..54f8753304b7e6c644aae5c03fe1be93b77ee81e 100644 --- a/Plotting.py +++ b/Plotting.py @@ -3,6 +3,7 @@ @author: Laura C. Kühle TODO: Give option to select plotting color +TODO: Improve classification plotting """ import numpy as np @@ -235,7 +236,7 @@ def calculate_exact_solution(mesh, cell_len, wave_speed, final_time, interval_le return grid, exact -def plot_classification_accuracy(xlabels, precision, recall, accuracy): +def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, auroc): """Plots classification accuracy. Plots the accuracy, precision, and recall in a bar plot. @@ -255,13 +256,17 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy): precision = [precision] recall = [recall] accuracy = [accuracy] + fscore = [fscore] + auroc = [auroc] pos = np.arange(len(xlabels)) width = 0.3 fig = plt.figure('classification_accuracy') ax = fig.add_axes([0.15, 0.1, 0.75, 0.8]) + ax.bar(pos - 2*width, fscore, width, label='F-Score') ax.bar(pos - width, precision, width, label='Precision') ax.bar(pos, recall, width, label='Recall') ax.bar(pos + width, accuracy, width, label='Accuracy') + ax.bar(pos + 2*width, auroc, width, label='AUROC') ax.set_xticks(pos) ax.set_xticklabels(xlabels) ax.set_ylabel('Classification (%)')