diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index 4f753dd53ae456b10929797c08a3a8e5d2af3e89..acf0d9535462909648045dc09893b0251b6bb79e 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -22,50 +22,27 @@ class TrainingDataGenerator: Generates random training data for given initial conditions. - Attributes - ---------- - smooth_functions : list - List of smooth initial/continuous conditions. - troubled_functions : list - List of discontinuous initial conditions. - data_dir : str - Path to directory in which training data is saved. - Methods ------- build_training_data(num_samples) Builds random training data. """ - def __init__(self, initial_conditions, left_bound=-1, right_bound=1, - balance=0.5, stencil_length=3, directory='test_data', - add_reconstructions=True): + def __init__(self, left_bound=-1, right_bound=1, stencil_length=3): """Initializes TrainingDataGenerator. Parameters ---------- - initial_conditions : list - List of names of initial conditions for training. left_bound : float, optional Left boundary of interval. Default: -1. right_bound : float, optional Right boundary of interval. Default: 1. - balance: float, optional - Ratio between smooth and discontinuous training data. Default: 0.5. stencil_length : int, optional Size of training data array. Default: 3. - directory : str, optional - Path to directory in which training data is saved. - Default: 'test_data'. - add_reconstructions: bool, optional - Flag whether reconstructions of the middle cell are included. - Default: True. """ - self._balance = balance self._left_bound = left_bound self._right_bound = right_bound - self._add_reconstructions = add_reconstructions # Set stencil length if stencil_length % 2 == 0: @@ -73,29 +50,26 @@ class TrainingDataGenerator: % stencil_length) self._stencil_length = stencil_length - # Separate smooth and discontinuous initial conditions - self._smooth_functions = [] - self._troubled_functions = [] - for function in initial_conditions: - if function['function'].is_smooth(): - self._smooth_functions.append(function) - else: - self._troubled_functions.append(function) - - # Set directory - self._data_dir = directory - if not os.path.exists(self._data_dir): - os.makedirs(self._data_dir) - - def build_training_data(self, num_samples): + def build_training_data(self, initial_conditions, num_samples, balance=0.5, + directory='test_data', add_reconstructions=True): """Builds random training data. Creates training data consisting of random ANN input and saves it. Parameters ---------- + initial_conditions : list + List of names of initial conditions for training. num_samples : int Number of training data samples to generate. + balance : float, optional + Ratio between smooth and discontinuous training data. Default: 0.5. + directory : str, optional + Path to directory in which training data is saved. + Default: 'test_data'. + add_reconstructions : bool, optional + Flag whether reconstructions of the middle cell are included. + Default: True. Returns ------- @@ -106,15 +80,18 @@ class TrainingDataGenerator: """ tic = time.perf_counter() print('Calculating training data...\n') - data_dict = self._calculate_data_set(num_samples) + data_dict = self._calculate_data_set(initial_conditions, + num_samples, balance, + add_reconstructions) print('Finished calculating training data!') - self._save_data(data_dict) + self._save_data(directory=directory, data=data_dict) toc = time.perf_counter() print(f'Total runtime: {toc - tic:0.4f}s') return data_dict - def _calculate_data_set(self, num_samples): + def _calculate_data_set(self, initial_conditions, num_samples, balance, + add_reconstructions): """Calculates random training data of given stencil length. Creates training data with a given ratio between smooth and @@ -122,8 +99,14 @@ class TrainingDataGenerator: Parameters ---------- + initial_conditions : list + List of names of initial conditions for training. num_samples : int Number of training data samples to generate. + balance : float + Ratio between smooth and discontinuous training data. + add_reconstructions : bool + Flag whether reconstructions of the middle cell are included. Returns ------- @@ -132,13 +115,24 @@ class TrainingDataGenerator: output data. """ - num_smooth_samples = round(num_samples * self._balance) + # print(type(initial_conditions)) + # Separate smooth and discontinuous initial conditions + smooth_functions = [] + troubled_functions = [] + for function in initial_conditions: + if function['function'].is_smooth(): + smooth_functions.append(function) + else: + troubled_functions.append(function) + + num_smooth_samples = round(num_samples * balance) smooth_input, smooth_output = self._generate_cell_data( - num_smooth_samples, self._smooth_functions, True) + num_smooth_samples, smooth_functions, add_reconstructions, True) num_troubled_samples = num_samples - num_smooth_samples troubled_input, troubled_output = self._generate_cell_data( - num_troubled_samples, self._troubled_functions, False) + num_troubled_samples, troubled_functions, add_reconstructions, + False) # Merge Data input_matrix = np.concatenate((smooth_input, troubled_input), axis=0) @@ -157,7 +151,8 @@ class TrainingDataGenerator: return {'input_data.raw': input_matrix, 'output_data': output_matrix, 'input_data.normalized': norm_input_matrix} - def _generate_cell_data(self, num_samples, initial_conditions, is_smooth): + def _generate_cell_data(self, num_samples, initial_conditions, + add_reconstructions, is_smooth): """Generates random training input and output. Generates random training input and output for either smooth or @@ -170,6 +165,8 @@ class TrainingDataGenerator: Number of training data samples to generate. initial_conditions : list List of names of initial conditions for training. + add_reconstructions : bool + Flag whether reconstructions of the middle cell are included. is_smooth : bool Flag whether initial conditions are smooth. @@ -181,13 +178,14 @@ class TrainingDataGenerator: Array containing output data. """ + # print(type(initial_conditions)) troubled_indicator = 'without' if is_smooth else 'with' print('Calculating data ' + troubled_indicator + ' troubled cells...') print('Samples to complete:', num_samples) tic = time.perf_counter() num_datapoints = self._stencil_length - if self._add_reconstructions: + if add_reconstructions: num_datapoints += 2 input_data = np.zeros((num_samples, num_datapoints)) num_init_cond = len(initial_conditions) @@ -212,19 +210,19 @@ class TrainingDataGenerator: # Calculate basis coefficients for stencil polynomial_degree = np.random.randint(1, high=5) - basis = basis_list[polynomial_degree] mesh = Mesh(num_grid_cells=self._stencil_length, num_ghost_cells=2, left_bound=left_bound, right_bound=right_bound) projection = do_initial_projection( initial_condition=initial_condition, mesh=mesh, - basis=basis, + basis=basis_list[polynomial_degree], quadrature=quadrature_list[polynomial_degree], adjustment=adjustment) - input_data[i] = basis.calculate_cell_average( + input_data[i] = basis_list[ + polynomial_degree].calculate_cell_average( projection=projection[:, 1:-1], stencil_length=self._stencil_length, - add_reconstructions=self._add_reconstructions) + add_reconstructions=add_reconstructions) count += 1 if count % 1000 == 0: @@ -298,9 +296,14 @@ class TrainingDataGenerator: normalized_input_data.append(entry / max_function_value) return np.array(normalized_input_data) - def _save_data(self, data): + @staticmethod + def _save_data(directory, data): """Saves data.""" + # Set directory + if not os.path.exists(directory): + os.makedirs(directory) + print('Saving training data.') for key in data.keys(): - name = self._data_dir + '/' + key + '.npy' + name = directory + '/' + key + '.npy' np.save(name, data[key]) diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk index df966e1354d4b253d156514afa2c382c71bbd529..e4a33c101d522460a3f192fdf6a658ce59292db3 100644 --- a/workflows/ANN_data.smk +++ b/workflows/ANN_data.smk @@ -38,10 +38,9 @@ rule generate_data: with open(str(log), 'w') as logfile: sys.stdout = logfile generator = ANN_Data_Generator.TrainingDataGenerator( - initial_conditions=initial_conditions, left_bound=params.left_bound, right_bound=params.right_bound, - balance=params.balance, - stencil_length=params.stencil_length, directory=DIR, - add_reconstructions=params.reconstruction_flag) - data = generator.build_training_data( - num_samples=params.sample_number) \ No newline at end of file + stencil_length=params.stencil_length) + data = generator.build_training_data(balance=params.balance, + initial_conditions=initial_conditions, directory=DIR, + num_samples=params.sample_number, + add_reconstructions=params.reconstruction_flag) \ No newline at end of file