Skip to content
Snippets Groups Projects
Commit 2c9ebeb3 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Changed code to only allow models from ANN_Model.

parent 9246734a
Branches
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
TODO: Fix cell averages and reconstructions to create data with an x-point stencil
TODO: Only allow model from 'ANN_Model' -> Done
"""
import os
......@@ -12,6 +13,8 @@ import seaborn as sns
import torch
from sympy import Symbol
import ANN_Model
x = Symbol('x')
z = Symbol('z')
......@@ -196,14 +199,13 @@ class ArtificialNeuralNetwork(TroubledCellDetector):
super()._reset(config)
self._stencil_len = config.pop('stencil_len', 3)
self._model = config.pop('model')
self._model_state = config.pop('model_state', 'Train24k24k_Valid8k8k_Norm12ReLU10nodesAdamlr1e-2MSE.pt')
# training_dir = config.pop('data_dir', 'data')
# training_file = config.pop('training_set', 'smooth_0.01k__troubled_0.01k__normalized.npy')
# validation_file = config.pop('validation_set', 'smooth_0.01k__troubled_0.01k__normalized.npy')
# test_file = config.pop('test_set', 'smooth_0.01k__troubled_0.01k__normalized.npy')
# batch_size = config.pop('batch_size', 500)
# self._training_data = {'train': [], 'validation': [], 'test': []}
self._model = config.pop('model', 'ThreeLayerNetDifferentNeuronsSoftMax')
self._model_config = config.pop('model_config', {'d_in': self._stencil_len+2, 'h1': 8, 'h2': 4, 'd_out': 2})
self._model_state = config.pop('model_state', 'Train24k24k_Valid8k8k_Norm12ReLU8+4nodesSM1Adamlr1e-2MSE.pt')
if not hasattr(ANN_Model, self._model):
raise ValueError('Invalid model: "%s"' % self._model)
self._model = getattr(ANN_Model, self._model)(**self._model_config)
def get_cells(self, projection):
num_ghost_cells = self._stencil_len//2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment