Skip to content
Snippets Groups Projects
Commit 38895885 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Removed 'stencil_length' as instance variable.

parent b7f28556
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,7 @@ class TrainingDataGenerator:
Builds random training data.
"""
def __init__(self, left_bound=-1, right_bound=1, stencil_length=3):
def __init__(self, left_bound=-1, right_bound=1):
"""Initializes TrainingDataGenerator.
Parameters
......@@ -42,8 +42,6 @@ class TrainingDataGenerator:
Left boundary of interval. Default: -1.
right_bound : float, optional
Right boundary of interval. Default: 1.
stencil_length : int, optional
Size of training data array. Default: 3.
"""
self._basis_list = [OrthonormalLegendre(pol_deg)
......@@ -54,14 +52,9 @@ class TrainingDataGenerator:
num_ghost_cells=0, num_grid_cells=2**exp)
for exp in range(3, 9)]
# Set stencil length
if stencil_length % 2 == 0:
raise ValueError('Invalid stencil length (even value): "%d"'
% stencil_length)
self._stencil_length = stencil_length
def build_training_data(self, initial_conditions, num_samples, balance=0.5,
directory='test_data', add_reconstructions=True):
directory='test_data', add_reconstructions=True,
stencil_length=3):
"""Builds random training data.
Creates training data consisting of random ANN input and saves it.
......@@ -80,6 +73,8 @@ class TrainingDataGenerator:
add_reconstructions : bool, optional
Flag whether reconstructions of the middle cell are included.
Default: True.
stencil_length : int, optional
Size of training data array. Default: 3.
Returns
-------
......@@ -89,10 +84,17 @@ class TrainingDataGenerator:
"""
tic = time.perf_counter()
# Set stencil length
if stencil_length % 2 == 0:
raise ValueError('Invalid stencil length (even value): "%d"'
% stencil_length)
print('Calculating training data...\n')
data_dict = self._calculate_data_set(initial_conditions,
num_samples, balance,
add_reconstructions)
add_reconstructions,
stencil_length)
print('Finished calculating training data!')
self._save_data(directory=directory, data=data_dict)
......@@ -101,7 +103,7 @@ class TrainingDataGenerator:
return data_dict
def _calculate_data_set(self, initial_conditions, num_samples, balance,
add_reconstructions):
add_reconstructions, stencil_length):
"""Calculates random training data of given stencil length.
Creates training data with a given ratio between smooth and
......@@ -117,6 +119,8 @@ class TrainingDataGenerator:
Ratio between smooth and discontinuous training data.
add_reconstructions : bool
Flag whether reconstructions of the middle cell are included.
stencil_length : int
Size of training data array.
Returns
-------
......@@ -137,12 +141,13 @@ class TrainingDataGenerator:
num_smooth_samples = round(num_samples * balance)
smooth_input, smooth_output = self._generate_cell_data(
num_smooth_samples, smooth_functions, add_reconstructions, True)
num_smooth_samples, smooth_functions, add_reconstructions,
stencil_length, True)
num_troubled_samples = num_samples - num_smooth_samples
troubled_input, troubled_output = self._generate_cell_data(
num_troubled_samples, troubled_functions, add_reconstructions,
False)
stencil_length, False)
# Merge Data
input_matrix = np.concatenate((smooth_input, troubled_input), axis=0)
......@@ -162,7 +167,7 @@ class TrainingDataGenerator:
'input_data.normalized': norm_input_matrix}
def _generate_cell_data(self, num_samples, initial_conditions,
add_reconstructions, is_smooth):
add_reconstructions, stencil_length, is_smooth):
"""Generates random training input and output.
Generates random training input and output for either smooth or
......@@ -177,6 +182,8 @@ class TrainingDataGenerator:
List of names of initial conditions for training.
add_reconstructions : bool
Flag whether reconstructions of the middle cell are included.
stencil_length : int
Size of training data array.
is_smooth : bool
Flag whether initial conditions are smooth.
......@@ -194,7 +201,7 @@ class TrainingDataGenerator:
print('Samples to complete:', num_samples)
tic = time.perf_counter()
num_datapoints = self._stencil_length
num_datapoints = stencil_length
if add_reconstructions:
num_datapoints += 2
input_data = np.zeros((num_samples, num_datapoints))
......@@ -209,11 +216,11 @@ class TrainingDataGenerator:
# Build mesh for random stencil of given length
mesh = self._mesh_list[int(np.random.randint(
3, high=9, size=1))-3].random_stencil(self._stencil_length)
3, high=9, size=1))-3].random_stencil(stencil_length)
# Induce adjustment to capture troubled cells
adjustment = 0 if initial_condition.is_smooth() \
else mesh.non_ghost_cells[self._stencil_length//2]
else mesh.non_ghost_cells[stencil_length//2]
initial_condition.induce_adjustment(-mesh.cell_len/3)
# Calculate basis coefficients for stencil
......@@ -226,7 +233,7 @@ class TrainingDataGenerator:
input_data[i] = self._basis_list[
polynomial_degree].calculate_cell_average(
projection=projection[:, 1:-1],
stencil_length=self._stencil_length,
stencil_length=stencil_length,
add_reconstructions=add_reconstructions)
count += 1
......
......@@ -38,9 +38,9 @@ rule generate_data:
with open(str(log), 'w') as logfile:
sys.stdout = logfile
generator = ANN_Data_Generator.TrainingDataGenerator(
left_bound=params.left_bound, right_bound=params.right_bound,
stencil_length=params.stencil_length)
left_bound=params.left_bound, right_bound=params.right_bound)
data = generator.build_training_data(balance=params.balance,
initial_conditions=initial_conditions, directory=DIR,
num_samples=params.sample_number,
add_reconstructions=params.reconstruction_flag)
\ No newline at end of file
add_reconstructions=params.reconstruction_flag,
stencil_length=params.stencil_length)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment