Skip to content
Snippets Groups Projects
Commit 7209f885 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Implemented 'get_cells()' for 'ArtificialNeuralNetwork'.

parent 63f5eca3
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ TODO: Adapt 'train()' to fit style
TODO: Add ANN testing from Soraya to ANN implementation
TODO: Add ANN classification from Soraya to ANN implementation
TODO: Move ANN implementation to new file -> Done (Artificial_Neural_Network)
TODO: Add ANN detection from Soraya to ANN
TODO: Implement 'get_cells()' for 'ArtificialNeuralNetwork' -> Done
TODO: Adapt calculate_approximate_solution() to not require a quadrature -> Done
TODO: Add function to determine cell average and reconstructions -> Done
TODO: Fix cell averages and reconstructions to create data with an x-point stencil
......@@ -16,6 +16,7 @@ import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sympy import Symbol
x = Symbol('x')
......@@ -62,8 +63,8 @@ class TroubledCellDetector(object):
cell_averages = self._calculate_approximate_solution(projection, [0], 0)
left_reconstructions = self._calculate_approximate_solution(projection, [-1], self._polynomial_degree)
right_reconstructions = self._calculate_approximate_solution(projection, [1], self._polynomial_degree)
return np.array(list(zip(cell_averages[:, 0], left_reconstructions[:, 1], cell_averages[:, 1],
right_reconstructions[:, 1], cell_averages[:, 2])))
return np.array(list(map(np.float64, zip(cell_averages[:, 0], left_reconstructions[:, 1], cell_averages[:, 1],
right_reconstructions[:, 1], cell_averages[:, 2]))))
def plot_results(self, projection, troubled_cell_history, time_history):
self._plot_shock_tube(troubled_cell_history, time_history)
......@@ -201,6 +202,9 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
def _reset(self, config):
super()._reset(config)
self._stencil_len = config.pop('stencil_len', 3)
self._model = config.pop('model')
self._model_state = config.pop('model_state', 'Train24k24k_Valid8k8k_Norm12ReLU10nodesAdamlr1e-2MSE.pt')
# training_dir = config.pop('data_dir', 'data')
# training_file = config.pop('training_set', 'smooth_0.01k__troubled_0.01k__normalized.npy')
# validation_file = config.pop('validation_set', 'smooth_0.01k__troubled_0.01k__normalized.npy')
......@@ -209,7 +213,21 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
# self._training_data = {'train': [], 'validation': [], 'test': []}
def get_cells(self, projection):
pass
num_ghost_cells = self._stencil_len//2
projection = projection[:, 1:-1]
# projection = projection[:, :5]
projection = np.concatenate((projection[:, -num_ghost_cells:], projection, projection[:, :num_ghost_cells]),
axis=1)
input_data = torch.from_numpy(np.vstack([self.calculate_cell_average_and_reconstructions(
projection[:, cell-num_ghost_cells:cell+num_ghost_cells+1])
for cell in range(num_ghost_cells, len(projection[0])-num_ghost_cells)]))
self._model.load_state_dict(torch.load(self._model_state))
self._model.eval()
model_output = torch.round(self._model(input_data.float()))
return [cell for cell in range(len(model_output)) if model_output[cell, 0] == torch.tensor([1])]
class WaveletDetector(TroubledCellDetector):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment