From c8a41e549c02eab87c72ddde01fcb2b522933181 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Wed, 19 Jan 2022 23:36:54 +0100
Subject: [PATCH] Adapted ANN input data so that model output equals 1 for a
 troubled cell.

---
 ANN_Data_Generator.py     | 5 ++---
 Troubled_Cell_Detector.py | 3 ++-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py
index 14dbed1..64b0c7f 100644
--- a/ANN_Data_Generator.py
+++ b/ANN_Data_Generator.py
@@ -143,7 +143,7 @@ class TrainingDataGenerator(object):
         """Generates random training input and output.
 
         Generates random training input and output for either smooth or discontinuous
-        initial_conditions.
+        initial conditions. For each input the output has the shape [is_smooth, is_troubled]
 
         Parameters
         ----------
@@ -203,9 +203,8 @@ class TrainingDataGenerator(object):
         print('Calculation time:', toc-tic, '\n')
 
         # Set output data
-        # output_data = np.zeros(num_samples) if is_smooth else np.ones(num_samples)
         output_data = np.zeros((num_samples, 2))
-        output_data[:, int(is_smooth)] = np.ones(num_samples)
+        output_data[:, int(not is_smooth)] = np.ones(num_samples)
 
         return input_data, output_data
 
diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py
index d3dabe3..45fcd42 100644
--- a/Troubled_Cell_Detector.py
+++ b/Troubled_Cell_Detector.py
@@ -7,6 +7,7 @@ TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
 TODO: Give detailed description of wavelet detection
 TODO: Load ANN state and config in reset -> Done
 TODO: Fix bug in output calculation -> Done
+TODO: Adapt input data so that model output equals 1 for a troubled cell -> Done
 
 """
 import numpy as np
@@ -306,7 +307,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
         # Determine troubled cells
         model_output = torch.argmax(self._model(input_data.float()), dim=1)
         return [cell for cell in range(len(model_output))
-                if model_output[cell] == torch.tensor([0])]
+                if model_output[cell] == torch.tensor([1])]
 
 
 class WaveletDetector(TroubledCellDetector):
-- 
GitLab