diff --git a/ANN_Training.py b/ANN_Training.py
index 0f3afb7c3c0cb81e874a3c17d3dc923f4690e2dc..191efef5dec6b6c52abbf53206a0f8d899deb7e5 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -8,6 +8,9 @@ TODO: Add log to pipeline
 TODO: Remove object set-up
 TODO: Optimize Snakefile-vs-config relation
 TODO: Improve maximum selection runtime
+TODO: Discuss if we want training accuracy/ROC in addition to CFV
+TODO: Discuss whether to change output to binary
+TODO: Adapt TCD file to new classification
 
 """
 import numpy as np
@@ -17,7 +20,7 @@ import torch
 from torch.utils.data import TensorDataset, DataLoader, random_split
 from sklearn.model_selection import KFold
 # from sklearn.metrics import accuracy_score
-from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_score, roc_auc_score, roc_curve
 
 import ANN_Model
 from Plotting import plot_classification_accuracy
@@ -142,18 +145,28 @@ class ModelTrainer(object):
 
         x_test, y_test = test_set
         # print(self._model(x_test.float()))
+        model_score = self._model(x_test.float())
         model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
-                                     for value in torch.max(self._model(x_test.float()), 1)[1]])
+                                     for value in torch.max(model_score, 1)[1]])
 
-        y_true = y_test.detach().numpy()
-        y_pred = model_output.detach().numpy()
+        y_true = y_test.detach().numpy()[:, 0]
+        y_pred = model_output.detach().numpy()[:, 0]
+        # y_score = model_score.detach().numpy()[:, 0]
         accuracy = accuracy_score(y_true, y_pred)
         # print('sklearn', accuracy)
         precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
-        # print(precision, recall)
+        # print(precision, recall, f_score)
         # print()
-
-        return [precision[0], recall[0], accuracy]
+        # auroc = roc_auc_score(y_true, y_score)
+        # print('auroc raw', auroc)
+        auroc = roc_auc_score(y_true, y_pred)
+        print('auroc true', auroc)
+        # fpr, tpr, thresholds = roc_curve(y_true, y_score)
+        # roc = [tpr, fpr, thresholds]
+        # print(roc)
+        # plt.plot(fpr, tpr, label="AUC="+str(auroc))
+
+        return [precision[0], recall[0], accuracy, f_score[0], auroc]
 
     def save_model(self):
         # Saving Model
diff --git a/Plotting.py b/Plotting.py
index 77c3bcdb028a7b68effd4d08c1956b70690a65ee..54f8753304b7e6c644aae5c03fe1be93b77ee81e 100644
--- a/Plotting.py
+++ b/Plotting.py
@@ -3,6 +3,7 @@
 @author: Laura C. Kühle
 
 TODO: Give option to select plotting color
+TODO: Improve classification plotting
 
 """
 import numpy as np
@@ -235,7 +236,7 @@ def calculate_exact_solution(mesh, cell_len, wave_speed, final_time, interval_le
     return grid, exact
 
 
-def plot_classification_accuracy(xlabels, precision, recall, accuracy):
+def plot_classification_accuracy(xlabels, precision, recall, accuracy, fscore, auroc):
     """Plots classification accuracy.
 
     Plots the accuracy, precision, and recall in a bar plot.
@@ -255,13 +256,17 @@ def plot_classification_accuracy(xlabels, precision, recall, accuracy):
     precision = [precision]
     recall = [recall]
     accuracy = [accuracy]
+    fscore = [fscore]
+    auroc = [auroc]
     pos = np.arange(len(xlabels))
     width = 0.3
     fig = plt.figure('classification_accuracy')
     ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
+    ax.bar(pos - 2*width, fscore, width, label='F-Score')
     ax.bar(pos - width, precision, width, label='Precision')
     ax.bar(pos, recall, width, label='Recall')
     ax.bar(pos + width, accuracy, width, label='Accuracy')
+    ax.bar(pos + 2*width, auroc, width, label='AUROC')
     ax.set_xticks(pos)
     ax.set_xticklabels(xlabels)
     ax.set_ylabel('Classification (%)')