Skip to content
Snippets Groups Projects
Commit 3ded44d2 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Adapted code to generate both normalized and non-normalized data.

parent 82c15f81
Branches
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@
TODO: Improve '_generate_cell_data'
TODO: Extract normalization (Combine smooth and troubled before normalizing) -> Done
TODO: Adapt code to generate both normalized and non-normalized data -> Done
TODO: Improve verbose output
"""
......@@ -43,22 +44,20 @@ class TrainingDataGenerator(object):
if not os.path.exists(self._data_dir):
os.makedirs(self._data_dir)
def build_training_data(self, num_samples, normalize):
def build_training_data(self, num_samples):
print('Calculating training data...')
input_data, output_data = self._calculate_data_set(num_samples, normalize)
data = [input_data, output_data]
data_dict = self._calculate_data_set(num_samples)
print('Finished calculating training data!')
self._save_data(data, num_samples, normalize)
return data
self._save_data(data_dict)
return data_dict
def _save_data(self, data, num_samples, normalize):
input_name = self._data_dir + '/input_data.npy'
np.save(input_name, data[0])
output_name = self._data_dir + '/output_data.npy'
np.save(output_name, data[1])
def _save_data(self, data):
for key in data.keys():
name = self._data_dir + '/' + key + '_data.npy'
np.save(name, data[key])
def _calculate_data_set(self, num_samples, normalize):
def _calculate_data_set(self, num_samples):
num_smooth_samples = round(num_samples * self._balance)
smooth_input, smooth_output = self._generate_cell_data(num_smooth_samples,
self._smooth_functions, True)
......@@ -77,10 +76,10 @@ class TrainingDataGenerator(object):
output_matrix = output_matrix[order]
# Create normalized input data
if normalize:
input_matrix = self._normalize_data(input_matrix)
norm_input_matrix = self._normalize_data(input_matrix)
return [input_matrix, output_matrix]
return {'input': input_matrix, 'output': output_matrix,
'normalized_input': norm_input_matrix}
def _generate_cell_data(self, num_samples, initial_conditions, is_smooth):
num_function_samples = num_samples//len(initial_conditions)
......
......@@ -63,7 +63,7 @@ rule generate_data:
generator = ANN_Data_Generator.TrainingDataGenerator(initial_conditions,
left_bound=params.left_bound, right_bound=params.right_bound, balance=params.balance,
stencil_length=params.stencil_length, directory=DIR)
data = generator.build_training_data(params.sample_number, 1)
data = generator.build_training_data(params.sample_number)
# print(data[0])
rule train_model:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment