diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index d39452172f144bd841f4826d062e78c60ebad6ce..6d16dbde71a295411fb920e6867c967a07aa8e76 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -3,7 +3,7 @@ @author: Soraya Terrab (sorayaterrab), Laura C. Kühle TODO: Improve '_generate_cell_data' -TODO: Extract normalization (Combine smooth and troubled before normalizing) +TODO: Extract normalization (Combine smooth and troubled before normalizing) -> Done TODO: Improve verbose output """ @@ -67,11 +67,6 @@ class TrainingDataGenerator(object): troubled_input, troubled_output = self._generate_cell_data(num_troubled_samples, self._troubled_functions, False) - # Normalize data - if normalize: - smooth_input = self._normalize_data(smooth_input) - troubled_input = self._normalize_data(troubled_input) - # Merge Data input_matrix = np.concatenate((smooth_input, troubled_input), axis=0) output_matrix = np.concatenate((smooth_output, troubled_output), axis=0) @@ -81,7 +76,11 @@ class TrainingDataGenerator(object): input_matrix = input_matrix[order] output_matrix = output_matrix[order] - return input_matrix, output_matrix + # Create normalized input data + if normalize: + input_matrix = self._normalize_data(input_matrix) + + return [input_matrix, output_matrix] def _generate_cell_data(self, num_samples, initial_conditions, is_smooth): num_function_samples = num_samples//len(initial_conditions)