diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index 14e395cdfd26c7eca10320e547e7642f74cc96bd..1c80916ccda472d0cc401bde17c17911ed8ce3cb 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -203,9 +203,9 @@ class TrainingDataGenerator(object): print('Calculation time:', toc-tic, '\n') # Set output data + # output_data = np.zeros(num_samples) if is_smooth else np.ones(num_samples) output_data = np.zeros((num_samples, 2)) - output_index = 1 if is_smooth else 0 - output_data[:, output_index] = np.ones(num_samples) + output_data[:, int(is_smooth)] = np.ones(num_samples) return input_data, output_data diff --git a/ANN_Training.py b/ANN_Training.py index b51bb07eb9716ce0f22a9f974aa57a1a31c618e5..243993690fb16dce3ccc2796668d338ee59affba 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -5,8 +5,8 @@ TODO: Add log to pipeline TODO: Remove object set-up TODO: Optimize Snakefile-vs-config relation -TODO: Improve maximum selection runtime -TODO: Change output to binary +TODO: Improve maximum selection runtime -> Done +TODO: Change model output to binary -> Do? (changes training when applied in ANN_Model) TODO: Adapt TCD file to new classification TODO: Add evaluation for all classes (recall, precision, fscore) TODO: Add documentation @@ -106,11 +106,14 @@ class ModelTrainer(object): x_test, y_test = test_set # print(self._model(x_test.float())) model_score = self._model(x_test.float()) - model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0] - for value in torch.max(model_score, 1)[1]]) - - y_true = y_test.detach().numpy()[:, 0] - y_pred = model_output.detach().numpy()[:, 0] + # model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0] + # for value in torch.argmax(model_score, dim=1)]) + # print(model_output) + model_output = torch.argmax(model_score, dim=1) + # print(model_output) + + y_true = y_test.detach().numpy()[:, 1] + y_pred = model_output.detach().numpy() # y_score = model_score.detach().numpy()[:, 0] accuracy = accuracy_score(y_true, y_pred) # print('sklearn', accuracy)