Skip to content
Snippets Groups Projects
Commit 36e48e46 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Improved runtime of maximum output selection.

parent 9c4b5da7
Branches
No related tags found
No related merge requests found
......@@ -203,9 +203,9 @@ class TrainingDataGenerator(object):
print('Calculation time:', toc-tic, '\n')
# Set output data
# output_data = np.zeros(num_samples) if is_smooth else np.ones(num_samples)
output_data = np.zeros((num_samples, 2))
output_index = 1 if is_smooth else 0
output_data[:, output_index] = np.ones(num_samples)
output_data[:, int(is_smooth)] = np.ones(num_samples)
return input_data, output_data
......
......@@ -5,8 +5,8 @@
TODO: Add log to pipeline
TODO: Remove object set-up
TODO: Optimize Snakefile-vs-config relation
TODO: Improve maximum selection runtime
TODO: Change output to binary
TODO: Improve maximum selection runtime -> Done
TODO: Change model output to binary -> Do? (changes training when applied in ANN_Model)
TODO: Adapt TCD file to new classification
TODO: Add evaluation for all classes (recall, precision, fscore)
TODO: Add documentation
......@@ -106,11 +106,14 @@ class ModelTrainer(object):
x_test, y_test = test_set
# print(self._model(x_test.float()))
model_score = self._model(x_test.float())
model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
for value in torch.max(model_score, 1)[1]])
y_true = y_test.detach().numpy()[:, 0]
y_pred = model_output.detach().numpy()[:, 0]
# model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
# for value in torch.argmax(model_score, dim=1)])
# print(model_output)
model_output = torch.argmax(model_score, dim=1)
# print(model_output)
y_true = y_test.detach().numpy()[:, 1]
y_pred = model_output.detach().numpy()
# y_score = model_score.detach().numpy()[:, 0]
accuracy = accuracy_score(y_true, y_pred)
# print('sklearn', accuracy)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment