Skip to content
Snippets Groups Projects
Commit c8a41e54 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Adapted ANN input data so that model output equals 1 for a troubled cell.

parent fd22e0ca
No related branches found
No related tags found
No related merge requests found
......@@ -143,7 +143,7 @@ class TrainingDataGenerator(object):
"""Generates random training input and output.
Generates random training input and output for either smooth or discontinuous
initial_conditions.
initial conditions. For each input the output has the shape [is_smooth, is_troubled]
Parameters
----------
......@@ -203,9 +203,8 @@ class TrainingDataGenerator(object):
print('Calculation time:', toc-tic, '\n')
# Set output data
# output_data = np.zeros(num_samples) if is_smooth else np.ones(num_samples)
output_data = np.zeros((num_samples, 2))
output_data[:, int(is_smooth)] = np.ones(num_samples)
output_data[:, int(not is_smooth)] = np.ones(num_samples)
return input_data, output_data
......
......@@ -7,6 +7,7 @@ TODO: Adjust Boxplot approach (adjacent cells, outer fence, etc.)
TODO: Give detailed description of wavelet detection
TODO: Load ANN state and config in reset -> Done
TODO: Fix bug in output calculation -> Done
TODO: Adapt input data so that model output equals 1 for a troubled cell -> Done
"""
import numpy as np
......@@ -306,7 +307,7 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
# Determine troubled cells
model_output = torch.argmax(self._model(input_data.float()), dim=1)
return [cell for cell in range(len(model_output))
if model_output[cell] == torch.tensor([0])]
if model_output[cell] == torch.tensor([1])]
class WaveletDetector(TroubledCellDetector):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment