diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index 14dbed12f0a1206b330666ab8fe2d8cddb5a7c3a..64b0c7f71f4d3fea2b04c58215f9b9fd9dac1ef2 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -143,7 +143,7 @@ class TrainingDataGenerator(object): """Generates random training input and output. Generates random training input and output for either smooth or discontinuous - initial_conditions. + initial conditions. For each input the output has the shape [is_smooth, is_troubled] Parameters ---------- @@ -203,9 +203,8 @@ class TrainingDataGenerator(object): print('Calculation time:', toc-tic, '\n') # Set output data - # output_data = np.zeros(num_samples) if is_smooth else np.ones(num_samples) output_data = np.zeros((num_samples, 2)) - output_data[:, int(is_smooth)] = np.ones(num_samples) + output_data[:, int(not is_smooth)] = np.ones(num_samples) return input_data, output_data diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index d3dabe3585e99d5952869d7d13439fe41418a43b..45fcd4238e5463dae39c5fb19c69d515a3d8272f 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -7,6 +7,7 @@ TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.) TODO: Give detailed description of wavelet detection TODO: Load ANN state and config in reset -> Done TODO: Fix bug in output calculation -> Done +TODO: Adapt input data so that model output equals 1 for a troubled cell -> Done """ import numpy as np @@ -306,7 +307,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector): # Determine troubled cells model_output = torch.argmax(self._model(input_data.float()), dim=1) return [cell for cell in range(len(model_output)) - if model_output[cell] == torch.tensor([0])] + if model_output[cell] == torch.tensor([1])] class WaveletDetector(TroubledCellDetector):