diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index b7600bbf24f304a395831fea42a7f7b6a3579a00..c5f07896b924a2dc6fefb95a9189c8a2540e473a 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -33,7 +33,7 @@ class TrainingDataGenerator: Builds random training data. """ - def __init__(self, left_bound=-1, right_bound=1, stencil_length=3): + def __init__(self, left_bound=-1, right_bound=1): """Initializes TrainingDataGenerator. Parameters @@ -42,8 +42,6 @@ class TrainingDataGenerator: Left boundary of interval. Default: -1. right_bound : float, optional Right boundary of interval. Default: 1. - stencil_length : int, optional - Size of training data array. Default: 3. """ self._basis_list = [OrthonormalLegendre(pol_deg) @@ -54,14 +52,9 @@ class TrainingDataGenerator: num_ghost_cells=0, num_grid_cells=2**exp) for exp in range(3, 9)] - # Set stencil length - if stencil_length % 2 == 0: - raise ValueError('Invalid stencil length (even value): "%d"' - % stencil_length) - self._stencil_length = stencil_length - def build_training_data(self, initial_conditions, num_samples, balance=0.5, - directory='test_data', add_reconstructions=True): + directory='test_data', add_reconstructions=True, + stencil_length=3): """Builds random training data. Creates training data consisting of random ANN input and saves it. @@ -80,6 +73,8 @@ class TrainingDataGenerator: add_reconstructions : bool, optional Flag whether reconstructions of the middle cell are included. Default: True. + stencil_length : int, optional + Size of training data array. Default: 3. Returns ------- @@ -89,10 +84,17 @@ class TrainingDataGenerator: """ tic = time.perf_counter() + + # Set stencil length + if stencil_length % 2 == 0: + raise ValueError('Invalid stencil length (even value): "%d"' + % stencil_length) + print('Calculating training data...\n') data_dict = self._calculate_data_set(initial_conditions, num_samples, balance, - add_reconstructions) + add_reconstructions, + stencil_length) print('Finished calculating training data!') self._save_data(directory=directory, data=data_dict) @@ -101,7 +103,7 @@ class TrainingDataGenerator: return data_dict def _calculate_data_set(self, initial_conditions, num_samples, balance, - add_reconstructions): + add_reconstructions, stencil_length): """Calculates random training data of given stencil length. Creates training data with a given ratio between smooth and @@ -117,6 +119,8 @@ class TrainingDataGenerator: Ratio between smooth and discontinuous training data. add_reconstructions : bool Flag whether reconstructions of the middle cell are included. + stencil_length : int + Size of training data array. Returns ------- @@ -137,12 +141,13 @@ class TrainingDataGenerator: num_smooth_samples = round(num_samples * balance) smooth_input, smooth_output = self._generate_cell_data( - num_smooth_samples, smooth_functions, add_reconstructions, True) + num_smooth_samples, smooth_functions, add_reconstructions, + stencil_length, True) num_troubled_samples = num_samples - num_smooth_samples troubled_input, troubled_output = self._generate_cell_data( num_troubled_samples, troubled_functions, add_reconstructions, - False) + stencil_length, False) # Merge Data input_matrix = np.concatenate((smooth_input, troubled_input), axis=0) @@ -162,7 +167,7 @@ class TrainingDataGenerator: 'input_data.normalized': norm_input_matrix} def _generate_cell_data(self, num_samples, initial_conditions, - add_reconstructions, is_smooth): + add_reconstructions, stencil_length, is_smooth): """Generates random training input and output. Generates random training input and output for either smooth or @@ -177,6 +182,8 @@ class TrainingDataGenerator: List of names of initial conditions for training. add_reconstructions : bool Flag whether reconstructions of the middle cell are included. + stencil_length : int + Size of training data array. is_smooth : bool Flag whether initial conditions are smooth. @@ -194,7 +201,7 @@ class TrainingDataGenerator: print('Samples to complete:', num_samples) tic = time.perf_counter() - num_datapoints = self._stencil_length + num_datapoints = stencil_length if add_reconstructions: num_datapoints += 2 input_data = np.zeros((num_samples, num_datapoints)) @@ -209,11 +216,11 @@ class TrainingDataGenerator: # Build mesh for random stencil of given length mesh = self._mesh_list[int(np.random.randint( - 3, high=9, size=1))-3].random_stencil(self._stencil_length) + 3, high=9, size=1))-3].random_stencil(stencil_length) # Induce adjustment to capture troubled cells adjustment = 0 if initial_condition.is_smooth() \ - else mesh.non_ghost_cells[self._stencil_length//2] + else mesh.non_ghost_cells[stencil_length//2] initial_condition.induce_adjustment(-mesh.cell_len/3) # Calculate basis coefficients for stencil @@ -226,7 +233,7 @@ class TrainingDataGenerator: input_data[i] = self._basis_list[ polynomial_degree].calculate_cell_average( projection=projection[:, 1:-1], - stencil_length=self._stencil_length, + stencil_length=stencil_length, add_reconstructions=add_reconstructions) count += 1 diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk index e4a33c101d522460a3f192fdf6a658ce59292db3..1e2bf9061c9a729e8272be9d6248a6c739a77985 100644 --- a/workflows/ANN_data.smk +++ b/workflows/ANN_data.smk @@ -38,9 +38,9 @@ rule generate_data: with open(str(log), 'w') as logfile: sys.stdout = logfile generator = ANN_Data_Generator.TrainingDataGenerator( - left_bound=params.left_bound, right_bound=params.right_bound, - stencil_length=params.stencil_length) + left_bound=params.left_bound, right_bound=params.right_bound) data = generator.build_training_data(balance=params.balance, initial_conditions=initial_conditions, directory=DIR, num_samples=params.sample_number, - add_reconstructions=params.reconstruction_flag) \ No newline at end of file + add_reconstructions=params.reconstruction_flag, + stencil_length=params.stencil_length) \ No newline at end of file