diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index 6d16dbde71a295411fb920e6867c967a07aa8e76..45f96137c65585f8cd197e43b68818acd214dcfa 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -4,6 +4,7 @@ TODO: Improve '_generate_cell_data' TODO: Extract normalization (Combine smooth and troubled before normalizing) -> Done +TODO: Adapt code to generate both normalized and non-normalized data -> Done TODO: Improve verbose output """ @@ -43,22 +44,20 @@ class TrainingDataGenerator(object): if not os.path.exists(self._data_dir): os.makedirs(self._data_dir) - def build_training_data(self, num_samples, normalize): + def build_training_data(self, num_samples): print('Calculating training data...') - input_data, output_data = self._calculate_data_set(num_samples, normalize) - data = [input_data, output_data] + data_dict = self._calculate_data_set(num_samples) print('Finished calculating training data!') - self._save_data(data, num_samples, normalize) - return data + self._save_data(data_dict) + return data_dict - def _save_data(self, data, num_samples, normalize): - input_name = self._data_dir + '/input_data.npy' - np.save(input_name, data[0]) - output_name = self._data_dir + '/output_data.npy' - np.save(output_name, data[1]) + def _save_data(self, data): + for key in data.keys(): + name = self._data_dir + '/' + key + '_data.npy' + np.save(name, data[key]) - def _calculate_data_set(self, num_samples, normalize): + def _calculate_data_set(self, num_samples): num_smooth_samples = round(num_samples * self._balance) smooth_input, smooth_output = self._generate_cell_data(num_smooth_samples, self._smooth_functions, True) @@ -77,10 +76,10 @@ class TrainingDataGenerator(object): output_matrix = output_matrix[order] # Create normalized input data - if normalize: - input_matrix = self._normalize_data(input_matrix) + norm_input_matrix = self._normalize_data(input_matrix) - return [input_matrix, output_matrix] + return {'input': input_matrix, 'output': output_matrix, + 'normalized_input': norm_input_matrix} def _generate_cell_data(self, num_samples, initial_conditions, is_smooth): num_function_samples = num_samples//len(initial_conditions) diff --git a/Snakefile b/Snakefile index c084f7f7fa24c9ff5bd8c61a3c6b92b6dd4a10f3..401616ffe3fc8007fe17bfad9994e59886bc46a1 100644 --- a/Snakefile +++ b/Snakefile @@ -63,7 +63,7 @@ rule generate_data: generator = ANN_Data_Generator.TrainingDataGenerator(initial_conditions, left_bound=params.left_bound, right_bound=params.right_bound, balance=params.balance, stencil_length=params.stencil_length, directory=DIR) - data = generator.build_training_data(params.sample_number, 1) + data = generator.build_training_data(params.sample_number) # print(data[0]) rule train_model: