diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index d92273164a61039827f4841f8ee7c482b44107c8..ebc71d2c80bb79e86b07697c961055a38f413fd3 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -3,6 +3,7 @@ @author: Laura C. Kühle, Soraya Terrab (sorayaterrab) TODO: Fix cell averages and reconstructions to create data with an x-point stencil +TODO: Add comments to get_cells() for ArtificialNeuralNetwork -> Done """ import os @@ -210,19 +211,22 @@ class ArtificialNeuralNetwork(TroubledCellDetector): self._model = getattr(ANN_Model, self._model)(self._model_config) def get_cells(self, projection): + # Reset ghost cells to adjust for stencil length num_ghost_cells = self._stencil_len//2 - projection = projection[:, 1:-1] - # projection = projection[:, :5] projection = np.concatenate((projection[:, -num_ghost_cells:], projection, projection[:, :num_ghost_cells]), axis=1) + + # Calculate input data depending on stencil length input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions( projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1]) for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)])) + # Evaluate troubled cell probabilities self._model.load_state_dict(torch.load(self._model_state)) self._model.eval() + # Return troubled cells model_output = torch.round(self._model(input_data.float())) return [cell for cell in range(len(model_output)) if model_output[cell, 0] == torch.tensor([1])]