Skip to content
Snippets Groups Projects
Commit 38895885 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Removed 'stencil_length' as instance variable.

parent b7f28556
Branches
No related tags found
No related merge requests found
...@@ -33,7 +33,7 @@ class TrainingDataGenerator: ...@@ -33,7 +33,7 @@ class TrainingDataGenerator:
Builds random training data. Builds random training data.
""" """
def __init__(self, left_bound=-1, right_bound=1, stencil_length=3): def __init__(self, left_bound=-1, right_bound=1):
"""Initializes TrainingDataGenerator. """Initializes TrainingDataGenerator.
Parameters Parameters
...@@ -42,8 +42,6 @@ class TrainingDataGenerator: ...@@ -42,8 +42,6 @@ class TrainingDataGenerator:
Left boundary of interval. Default: -1. Left boundary of interval. Default: -1.
right_bound : float, optional right_bound : float, optional
Right boundary of interval. Default: 1. Right boundary of interval. Default: 1.
stencil_length : int, optional
Size of training data array. Default: 3.
""" """
self._basis_list = [OrthonormalLegendre(pol_deg) self._basis_list = [OrthonormalLegendre(pol_deg)
...@@ -54,14 +52,9 @@ class TrainingDataGenerator: ...@@ -54,14 +52,9 @@ class TrainingDataGenerator:
num_ghost_cells=0, num_grid_cells=2**exp) num_ghost_cells=0, num_grid_cells=2**exp)
for exp in range(3, 9)] for exp in range(3, 9)]
# Set stencil length
if stencil_length % 2 == 0:
raise ValueError('Invalid stencil length (even value): "%d"'
% stencil_length)
self._stencil_length = stencil_length
def build_training_data(self, initial_conditions, num_samples, balance=0.5, def build_training_data(self, initial_conditions, num_samples, balance=0.5,
directory='test_data', add_reconstructions=True): directory='test_data', add_reconstructions=True,
stencil_length=3):
"""Builds random training data. """Builds random training data.
Creates training data consisting of random ANN input and saves it. Creates training data consisting of random ANN input and saves it.
...@@ -80,6 +73,8 @@ class TrainingDataGenerator: ...@@ -80,6 +73,8 @@ class TrainingDataGenerator:
add_reconstructions : bool, optional add_reconstructions : bool, optional
Flag whether reconstructions of the middle cell are included. Flag whether reconstructions of the middle cell are included.
Default: True. Default: True.
stencil_length : int, optional
Size of training data array. Default: 3.
Returns Returns
------- -------
...@@ -89,10 +84,17 @@ class TrainingDataGenerator: ...@@ -89,10 +84,17 @@ class TrainingDataGenerator:
""" """
tic = time.perf_counter() tic = time.perf_counter()
# Set stencil length
if stencil_length % 2 == 0:
raise ValueError('Invalid stencil length (even value): "%d"'
% stencil_length)
print('Calculating training data...\n') print('Calculating training data...\n')
data_dict = self._calculate_data_set(initial_conditions, data_dict = self._calculate_data_set(initial_conditions,
num_samples, balance, num_samples, balance,
add_reconstructions) add_reconstructions,
stencil_length)
print('Finished calculating training data!') print('Finished calculating training data!')
self._save_data(directory=directory, data=data_dict) self._save_data(directory=directory, data=data_dict)
...@@ -101,7 +103,7 @@ class TrainingDataGenerator: ...@@ -101,7 +103,7 @@ class TrainingDataGenerator:
return data_dict return data_dict
def _calculate_data_set(self, initial_conditions, num_samples, balance, def _calculate_data_set(self, initial_conditions, num_samples, balance,
add_reconstructions): add_reconstructions, stencil_length):
"""Calculates random training data of given stencil length. """Calculates random training data of given stencil length.
Creates training data with a given ratio between smooth and Creates training data with a given ratio between smooth and
...@@ -117,6 +119,8 @@ class TrainingDataGenerator: ...@@ -117,6 +119,8 @@ class TrainingDataGenerator:
Ratio between smooth and discontinuous training data. Ratio between smooth and discontinuous training data.
add_reconstructions : bool add_reconstructions : bool
Flag whether reconstructions of the middle cell are included. Flag whether reconstructions of the middle cell are included.
stencil_length : int
Size of training data array.
Returns Returns
------- -------
...@@ -137,12 +141,13 @@ class TrainingDataGenerator: ...@@ -137,12 +141,13 @@ class TrainingDataGenerator:
num_smooth_samples = round(num_samples * balance) num_smooth_samples = round(num_samples * balance)
smooth_input, smooth_output = self._generate_cell_data( smooth_input, smooth_output = self._generate_cell_data(
num_smooth_samples, smooth_functions, add_reconstructions, True) num_smooth_samples, smooth_functions, add_reconstructions,
stencil_length, True)
num_troubled_samples = num_samples - num_smooth_samples num_troubled_samples = num_samples - num_smooth_samples
troubled_input, troubled_output = self._generate_cell_data( troubled_input, troubled_output = self._generate_cell_data(
num_troubled_samples, troubled_functions, add_reconstructions, num_troubled_samples, troubled_functions, add_reconstructions,
False) stencil_length, False)
# Merge Data # Merge Data
input_matrix = np.concatenate((smooth_input, troubled_input), axis=0) input_matrix = np.concatenate((smooth_input, troubled_input), axis=0)
...@@ -162,7 +167,7 @@ class TrainingDataGenerator: ...@@ -162,7 +167,7 @@ class TrainingDataGenerator:
'input_data.normalized': norm_input_matrix} 'input_data.normalized': norm_input_matrix}
def _generate_cell_data(self, num_samples, initial_conditions, def _generate_cell_data(self, num_samples, initial_conditions,
add_reconstructions, is_smooth): add_reconstructions, stencil_length, is_smooth):
"""Generates random training input and output. """Generates random training input and output.
Generates random training input and output for either smooth or Generates random training input and output for either smooth or
...@@ -177,6 +182,8 @@ class TrainingDataGenerator: ...@@ -177,6 +182,8 @@ class TrainingDataGenerator:
List of names of initial conditions for training. List of names of initial conditions for training.
add_reconstructions : bool add_reconstructions : bool
Flag whether reconstructions of the middle cell are included. Flag whether reconstructions of the middle cell are included.
stencil_length : int
Size of training data array.
is_smooth : bool is_smooth : bool
Flag whether initial conditions are smooth. Flag whether initial conditions are smooth.
...@@ -194,7 +201,7 @@ class TrainingDataGenerator: ...@@ -194,7 +201,7 @@ class TrainingDataGenerator:
print('Samples to complete:', num_samples) print('Samples to complete:', num_samples)
tic = time.perf_counter() tic = time.perf_counter()
num_datapoints = self._stencil_length num_datapoints = stencil_length
if add_reconstructions: if add_reconstructions:
num_datapoints += 2 num_datapoints += 2
input_data = np.zeros((num_samples, num_datapoints)) input_data = np.zeros((num_samples, num_datapoints))
...@@ -209,11 +216,11 @@ class TrainingDataGenerator: ...@@ -209,11 +216,11 @@ class TrainingDataGenerator:
# Build mesh for random stencil of given length # Build mesh for random stencil of given length
mesh = self._mesh_list[int(np.random.randint( mesh = self._mesh_list[int(np.random.randint(
3, high=9, size=1))-3].random_stencil(self._stencil_length) 3, high=9, size=1))-3].random_stencil(stencil_length)
# Induce adjustment to capture troubled cells # Induce adjustment to capture troubled cells
adjustment = 0 if initial_condition.is_smooth() \ adjustment = 0 if initial_condition.is_smooth() \
else mesh.non_ghost_cells[self._stencil_length//2] else mesh.non_ghost_cells[stencil_length//2]
initial_condition.induce_adjustment(-mesh.cell_len/3) initial_condition.induce_adjustment(-mesh.cell_len/3)
# Calculate basis coefficients for stencil # Calculate basis coefficients for stencil
...@@ -226,7 +233,7 @@ class TrainingDataGenerator: ...@@ -226,7 +233,7 @@ class TrainingDataGenerator:
input_data[i] = self._basis_list[ input_data[i] = self._basis_list[
polynomial_degree].calculate_cell_average( polynomial_degree].calculate_cell_average(
projection=projection[:, 1:-1], projection=projection[:, 1:-1],
stencil_length=self._stencil_length, stencil_length=stencil_length,
add_reconstructions=add_reconstructions) add_reconstructions=add_reconstructions)
count += 1 count += 1
......
...@@ -38,9 +38,9 @@ rule generate_data: ...@@ -38,9 +38,9 @@ rule generate_data:
with open(str(log), 'w') as logfile: with open(str(log), 'w') as logfile:
sys.stdout = logfile sys.stdout = logfile
generator = ANN_Data_Generator.TrainingDataGenerator( generator = ANN_Data_Generator.TrainingDataGenerator(
left_bound=params.left_bound, right_bound=params.right_bound, left_bound=params.left_bound, right_bound=params.right_bound)
stencil_length=params.stencil_length)
data = generator.build_training_data(balance=params.balance, data = generator.build_training_data(balance=params.balance,
initial_conditions=initial_conditions, directory=DIR, initial_conditions=initial_conditions, directory=DIR,
num_samples=params.sample_number, num_samples=params.sample_number,
add_reconstructions=params.reconstruction_flag) add_reconstructions=params.reconstruction_flag,
\ No newline at end of file stencil_length=params.stencil_length)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment