Skip to content
Snippets Groups Projects
Commit a594a37c authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Removed unnecessary comments.

parent c4c46355
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ TODO: Remove object set-up (for more flexibility)
TODO: Adapt TCD file to new classification
TODO: Add documentation
TODO: Improve log output
TODO: Remove unnecessary comments -> Done
"""
import numpy as np
......@@ -101,30 +102,14 @@ class ModelTrainer(object):
self._model.eval()
x_test, y_test = test_set
# print(self._model(x_test.float()))
model_score = self._model(x_test.float())
# model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
# for value in torch.argmax(model_score, dim=1)])
# print(model_output)
model_output = torch.argmax(model_score, dim=1)
# print(model_output)
y_true = y_test.detach().numpy()[:, 1]
y_pred = model_output.detach().numpy()
# y_score = model_score.detach().numpy()[:, 0]
accuracy = accuracy_score(y_true, y_pred)
# print('sklearn', accuracy)
precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
# print(precision, recall, f_score)
# print()
# auroc = roc_auc_score(y_true, y_score)
# print('auroc raw', auroc)
auroc = roc_auc_score(y_true, y_pred)
print('auroc true', auroc)
# fpr, tpr, thresholds = roc_curve(y_true, y_score)
# roc = [tpr, fpr, thresholds]
# print(roc)
# plt.plot(fpr, tpr, label="AUC="+str(auroc))
return {'Precision_Smooth': precision[0], 'Precision_Troubled': precision[1],
'Recall_Smooth': recall[0], 'Recall_Troubled': recall[1],
......@@ -143,9 +128,6 @@ class ModelTrainer(object):
torch.save(self._model.state_dict(), model_dir + '/model__' + name + '.pt')
torch.save(self._validation_loss, model_dir + '/loss__' + name + '.pt')
# def _classify(self):
# pass
def read_training_data(directory, normalized=True):
# Get training dataset from saved file and map to Torch tensor and dataset
......@@ -170,7 +152,6 @@ def evaluate_models(models, directory, num_iterations=100, colors=None,
for iteration in range(num_iterations):
for train_index, test_index in KFold(
n_splits=5, shuffle=True).split(datasets['normalized']):
# print("TRAIN:", train_index, "TEST:", test_index)
for dataset in datasets.keys():
training_set = TensorDataset(*datasets[dataset][train_index])
test_set = datasets[dataset][test_index]
......@@ -200,12 +181,3 @@ def evaluate_models(models, directory, num_iterations=100, colors=None,
plt.figure(identifier)
plt.savefig(plot_dir + '/' + identifier + '/' + '_'.join(models.keys()) + '.pdf')
# Loss Functions: BCELoss, BCEWithLogitsLoss,
# CrossEntropyLoss (not working), MSELoss (with reduction='sum')
# Optimizer: Adam, SGD
# trainer = ModelTrainer({'num_epochs': 1000})
# trainer.epoch_training()
# trainer.test_model()
# trainer.save_model()
......@@ -156,7 +156,7 @@ def plot_details(fine_projection, fine_mesh, coarse_projection, basis, wavelet,
def calculate_approximate_solution(projection, points, polynomial_degree, basis):
""""Calculates approximate solution.
"""Calculates approximate solution.
Parameters
----------
......@@ -273,7 +273,6 @@ def plot_classification_accuracy(evaluation_dict, colors):
ax.set_ylim(top=1.02)
ax.set_title('Classification Evaluation (Barplot)')
ax.legend(loc='upper right')
# fig.tight_layout()
def plot_boxplot(evaluation_dict, colors):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment