From cad015d0c1041e8e09bc6045a7cefae4ce4d4908 Mon Sep 17 00:00:00 2001
From: Kühle, Laura Christine (lakue103) <laura.kuehle@uni-duesseldorf.de>
Date: Fri, 20 Aug 2021 17:14:15 +0200
Subject: [PATCH] Added ANN testing (based on Soraya's implementation).

---
 ANN_Training.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++---
 Plotting.py     | 24 ++++++++++++++++++
 2 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/ANN_Training.py b/ANN_Training.py
index 80ac4a0..4fab8f5 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -3,20 +3,23 @@
 @author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
 
 TODO: Improve 'epoch_training()'
-TODO: Add ANN testing from Soraya
+TODO: Add ANN testing from Soraya -> Done
 TODO: Add ANN classification from Soraya
 TODO: Improve naming of training data/model (maybe different folders?)
 TODO: Adjust input file naming to fit training data -> Done
 TODO: Change code to add model directory if not existing -> Done
 TODO: Remove unnecessary comments -> Done
+TODO: Add option to set plot directory
 
 """
 import numpy as np
 import os
 import torch
 from torch.utils.data import TensorDataset, DataLoader
+# from sklearn.metrics import accuracy_score, precision_recall_fscore_support
 
 import ANN_Model
+from Plotting import plot_classification_accuracy
 
 
 class ModelTrainer(object):
@@ -104,6 +107,62 @@ class ModelTrainer(object):
             if valid_loss / len(valid_dl) < self._threshold:
                 break
 
+    def test_model(self):
+        self.epoch_training()
+        self._model.eval()
+
+        x_test, y_test = self._training_data['test']
+        model_output = torch.round(self._model(x_test.float()))
+        # acc = np.sum(model_output.numpy() == y_test.numpy())
+        test_accuracy = (model_output == y_test).float().mean()
+        print(test_accuracy)
+        # print(model_output.nelement())
+        # accuracy1 = torch.sum(torch.eq(model_output, y_test)).item()  # /model_output.nelement()
+        # print(test_accuracy, accuracy1/model_output.nelement())
+        # print(accuracy1)
+
+        tp, tn, fp, fn = self._evaluate_classification(model_output, y_test)
+        precision, recall, accuracy = self._evaluate_stats(tp, tn, fp, fn)
+        # print(precision, recall)
+        # print(accuracy)
+        plot_classification_accuracy(precision, recall, accuracy, ['real'])
+
+    @staticmethod
+    def _evaluate_classification(model_output, true_output):
+        # Positive being Discontinuous/Troubled Cells, Negative being Smooth/Good Cells
+        true_positive = 0
+        true_negative = 0
+        false_positive = 0
+        false_negative = 0
+        for i in range(true_output.size()[0]):
+            if model_output[i, 1] == model_output[i, 0]:
+                print(i, model_output[i])
+            if true_output[i, 0] == torch.tensor([1]):
+                if model_output[i, 0] == true_output[i, 0]:
+                    true_positive += 1
+                else:
+                    false_negative += 1
+            if true_output[i, 1] == torch.tensor([1]):
+                if model_output[i, 1] == true_output[i, 1]:
+                    true_negative += 1
+                else:
+                    false_positive += 1
+
+        return true_positive, true_negative, false_positive, false_negative
+
+    @staticmethod
+    def _evaluate_stats(true_positive, true_negative, false_positive, false_negative):
+        if true_positive+false_positive == 0:
+            precision = 0
+            recall = 0
+        else:
+            precision = true_positive / (true_positive+false_positive)
+            recall = true_positive / (true_positive+false_negative)
+
+        accuracy = (true_positive+true_negative) / (true_positive+true_negative+false_positive+false_negative)
+        # print(true_positive+true_negative+false_positive+false_negative)
+        return precision, recall, accuracy
+
     def save_model(self):
         # Saving Model
         train_name = self._training_file.split('.npy')[0]
@@ -124,6 +183,7 @@ class ModelTrainer(object):
 # Loss Functions: BCELoss, BCEWithLogitsLoss, CrossEntropyLoss (not working), MSELoss (with reduction='sum')
 # Optimizer: Adam, SGD
 
-trainer = ModelTrainer({'loss_function': 'MSELoss', 'loss_config': {'reduction': 'sum'}})
-trainer.epoch_training()
+trainer = ModelTrainer({'num_epochs': 100})
+# trainer.epoch_training()
+trainer.test_model()
 trainer.save_model()
diff --git a/Plotting.py b/Plotting.py
index 5267856..2fa5fdb 100644
--- a/Plotting.py
+++ b/Plotting.py
@@ -2,6 +2,8 @@
 """
 @author: Laura C. Kühle
 
+TODO: Give option to select plotting color
+
 """
 import numpy as np
 import matplotlib.pyplot as plt
@@ -112,3 +114,25 @@ def calculate_exact_solution(mesh, cell_len, wave_speed, final_time, interval_le
     grid = np.reshape(np.array(grid), (1, len(grid) * len(grid[0])))
 
     return grid, exact
+
+
+def plot_classification_accuracy(precision, recall, accuracy, xlabels):
+    precision = [precision]
+    recall = [recall]
+    accuracy = [accuracy]
+    pos = np.arange(len(xlabels))
+    width = 0.3
+    fig = plt.figure('classification_accuracy')
+    ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
+    ax.bar(pos - width, precision, width, label='Precision')
+    ax.bar(pos, recall, width, label='Recall')
+    ax.bar(pos + width, accuracy, width, label='Accuracy')
+    ax.set_xticks(pos)
+    ax.set_xticklabels(xlabels)
+    ax.set_ylabel('Classification (%)')
+    ax.set_ylim(bottom=0.6)
+    ax.set_ylim(top=1.02)
+    ax.set_title('Non-Normalized Test Data')
+    ax.legend(loc='upper right')
+    # fig.tight_layout()
+    fig.savefig('TestAdamPrecisionRecallAccuracy.pdf')
--
GitLab
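
Possible follow-up (not part of this patch): the commented-out sklearn.metrics
import in ANN_Training.py suggests the hand-written counting in
'_evaluate_classification'/'_evaluate_stats' could later be delegated to
scikit-learn. A minimal sketch under that assumption, using the hypothetical
helper name 'evaluate_with_sklearn' and taking column 0 as the positive
(troubled-cell) class, as in the loop above:

    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    def evaluate_with_sklearn(model_output, true_output):
        # Reduce the two-column 0/1 tensors to flat label vectors
        # (1 = troubled cell, 0 = good cell).
        y_pred = model_output[:, 0].detach().numpy()
        y_true = true_output[:, 0].numpy()
        precision, recall, _, _ = precision_recall_fscore_support(
            y_true, y_pred, average='binary', zero_division=0)
        accuracy = accuracy_score(y_true, y_pred)
        return precision, recall, accuracy

The rounding in 'test_model()' already yields 0/1 entries, so both label
vectors are valid inputs for the sklearn metrics.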
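
For reference, the new plotting helper can also be driven standalone; a short
usage sketch with made-up metric values (the numbers are illustrative only):

    from Plotting import plot_classification_accuracy

    # One bar group labelled 'real', as in ModelTrainer.test_model();
    # writes 'TestAdamPrecisionRecallAccuracy.pdf' to the working directory.
    plot_classification_accuracy(0.9, 0.85, 0.88, ['real'])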