Skip to content
Snippets Groups Projects
Commit a557d401 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Added explanatory comments to 'get_cells()' for the class 'ArtificialNeuralNetwork'.

parent 5a1746c6
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
TODO: Fix cell averages and reconstructions to create data with an x-point stencil
TODO: Add comments to get_cells() for ArtificialNeuralNetwork -> Done
"""
import os
......@@ -210,19 +211,22 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
self._model = getattr(ANN_Model, self._model)(self._model_config)
def get_cells(self, projection):
# Reset ghost cells to adjust for stencil length
num_ghost_cells = self._stencil_len//2
projection = projection[:, 1:-1]
# projection = projection[:, :5]
projection = np.concatenate((projection[:, -num_ghost_cells:], projection, projection[:, :num_ghost_cells]),
axis=1)
# Calculate input data depending on stencil length
input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions(
projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1])
for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))
# Evaluate troubled cell probabilities
self._model.load_state_dict(torch.load(self._model_state))
self._model.eval()
# Return troubled cells
model_output = torch.round(self._model(input_data.float()))
return [cell for cell in range(len(model_output)) if model_output[cell, 0] == torch.tensor([1])]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment