Skip to content
Snippets Groups Projects
Commit 36e48e46 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Improved runtime of maximum output selection.

parent 9c4b5da7
No related branches found
No related tags found
No related merge requests found
...@@ -203,9 +203,9 @@ class TrainingDataGenerator(object): ...@@ -203,9 +203,9 @@ class TrainingDataGenerator(object):
print('Calculation time:', toc-tic, '\n') print('Calculation time:', toc-tic, '\n')
# Set output data # Set output data
# output_data = np.zeros(num_samples) if is_smooth else np.ones(num_samples)
output_data = np.zeros((num_samples, 2)) output_data = np.zeros((num_samples, 2))
output_index = 1 if is_smooth else 0 output_data[:, int(is_smooth)] = np.ones(num_samples)
output_data[:, output_index] = np.ones(num_samples)
return input_data, output_data return input_data, output_data
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
TODO: Add log to pipeline TODO: Add log to pipeline
TODO: Remove object set-up TODO: Remove object set-up
TODO: Optimize Snakefile-vs-config relation TODO: Optimize Snakefile-vs-config relation
TODO: Improve maximum selection runtime TODO: Improve maximum selection runtime -> Done
TODO: Change output to binary TODO: Change model output to binary -> Do? (changes training when applied in ANN_Model)
TODO: Adapt TCD file to new classification TODO: Adapt TCD file to new classification
TODO: Add evaluation for all classes (recall, precision, fscore) TODO: Add evaluation for all classes (recall, precision, fscore)
TODO: Add documentation TODO: Add documentation
...@@ -106,11 +106,14 @@ class ModelTrainer(object): ...@@ -106,11 +106,14 @@ class ModelTrainer(object):
x_test, y_test = test_set x_test, y_test = test_set
# print(self._model(x_test.float())) # print(self._model(x_test.float()))
model_score = self._model(x_test.float()) model_score = self._model(x_test.float())
model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0] # model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
for value in torch.max(model_score, 1)[1]]) # for value in torch.argmax(model_score, dim=1)])
# print(model_output)
y_true = y_test.detach().numpy()[:, 0] model_output = torch.argmax(model_score, dim=1)
y_pred = model_output.detach().numpy()[:, 0] # print(model_output)
y_true = y_test.detach().numpy()[:, 1]
y_pred = model_output.detach().numpy()
# y_score = model_score.detach().numpy()[:, 0] # y_score = model_score.detach().numpy()[:, 0]
accuracy = accuracy_score(y_true, y_pred) accuracy = accuracy_score(y_true, y_pred)
# print('sklearn', accuracy) # print('sklearn', accuracy)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment