Skip to content
Snippets Groups Projects
Commit c8a41e54 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Adapted ANN input data so that model output equals 1 for a troubled cell.

parent fd22e0ca
Branches
No related tags found
No related merge requests found
...@@ -143,7 +143,7 @@ class TrainingDataGenerator(object): ...@@ -143,7 +143,7 @@ class TrainingDataGenerator(object):
"""Generates random training input and output. """Generates random training input and output.
Generates random training input and output for either smooth or discontinuous Generates random training input and output for either smooth or discontinuous
initial_conditions. initial conditions. For each input the output has the shape [is_smooth, is_troubled]
Parameters Parameters
---------- ----------
...@@ -203,9 +203,8 @@ class TrainingDataGenerator(object): ...@@ -203,9 +203,8 @@ class TrainingDataGenerator(object):
print('Calculation time:', toc-tic, '\n') print('Calculation time:', toc-tic, '\n')
# Set output data # Set output data
# output_data = np.zeros(num_samples) if is_smooth else np.ones(num_samples)
output_data = np.zeros((num_samples, 2)) output_data = np.zeros((num_samples, 2))
output_data[:, int(is_smooth)] = np.ones(num_samples) output_data[:, int(not is_smooth)] = np.ones(num_samples)
return input_data, output_data return input_data, output_data
......
...@@ -7,6 +7,7 @@ TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.) ...@@ -7,6 +7,7 @@ TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
TODO: Give detailed description of wavelet detection TODO: Give detailed description of wavelet detection
TODO: Load ANN state and config in reset -> Done TODO: Load ANN state and config in reset -> Done
TODO: Fix bug in output calculation -> Done TODO: Fix bug in output calculation -> Done
TODO: Adapt input data so that model output equals 1 for a troubled cell -> Done
""" """
import numpy as np import numpy as np
...@@ -306,7 +307,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector): ...@@ -306,7 +307,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
# Determine troubled cells # Determine troubled cells
model_output = torch.argmax(self._model(input_data.float()), dim=1) model_output = torch.argmax(self._model(input_data.float()), dim=1)
return [cell for cell in range(len(model_output)) return [cell for cell in range(len(model_output))
if model_output[cell] == torch.tensor([0])] if model_output[cell] == torch.tensor([1])]
class WaveletDetector(TroubledCellDetector): class WaveletDetector(TroubledCellDetector):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment