Commit a594a37c authored by Laura Christine Kühle

Removed unnecessary comments.

parent c4c46355
@@ -7,6 +7,7 @@ TODO: Remove object set-up (for more flexibility)
 TODO: Adapt TCD file to new classification
 TODO: Add documentation
 TODO: Improve log output
+TODO: Remove unnecessary comments -> Done
 """
 import numpy as np
@@ -101,30 +102,14 @@ class ModelTrainer(object):
         self._model.eval()
         x_test, y_test = test_set
-        # print(self._model(x_test.float()))
         model_score = self._model(x_test.float())
-        # model_output = torch.tensor([[1.0, 0.0] if value == 0 else [0.0, 1.0]
-        #                              for value in torch.argmax(model_score, dim=1)])
-        # print(model_output)
         model_output = torch.argmax(model_score, dim=1)
-        # print(model_output)
         y_true = y_test.detach().numpy()[:, 1]
         y_pred = model_output.detach().numpy()
-        # y_score = model_score.detach().numpy()[:, 0]
         accuracy = accuracy_score(y_true, y_pred)
-        # print('sklearn', accuracy)
         precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
-        # print(precision, recall, f_score)
-        # print()
-        # auroc = roc_auc_score(y_true, y_score)
-        # print('auroc raw', auroc)
         auroc = roc_auc_score(y_true, y_pred)
-        print('auroc true', auroc)
-        # fpr, tpr, thresholds = roc_curve(y_true, y_score)
-        # roc = [tpr, fpr, thresholds]
-        # print(roc)
-        # plt.plot(fpr, tpr, label="AUC="+str(auroc))
         return {'Precision_Smooth': precision[0], 'Precision_Troubled': precision[1],
                 'Recall_Smooth': recall[0], 'Recall_Troubled': recall[1],
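
The hunk above strips the debugging prints and keeps only the scikit-learn metric calls. For reference, a minimal standalone sketch of that metric pipeline on made-up labels (class 0 = 'Smooth', class 1 = 'Troubled'; none of this is part of the commit):

import numpy as np
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             roc_auc_score)

# Made-up ground truth and hard predictions, standing in for y_test[:, 1]
# and the argmax over the model scores.
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])

accuracy = accuracy_score(y_true, y_pred)
# Without an 'average' argument these come back as per-class arrays:
# index 0 = 'Smooth', index 1 = 'Troubled'.
precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
auroc = roc_auc_score(y_true, y_pred)
print(accuracy, precision[1], recall[1], auroc)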
@@ -143,9 +128,6 @@ class ModelTrainer(object):
         torch.save(self._model.state_dict(), model_dir + '/model__' + name + '.pt')
         torch.save(self._validation_loss, model_dir + '/loss__' + name + '.pt')

-    # def _classify(self):
-    #     pass

 def read_training_data(directory, normalized=True):
     # Get training dataset from saved file and map to Torch tensor and dataset
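
The two torch.save calls above persist the model weights and the validation loss separately. A hedged sketch of reading them back (the architecture and the directory/name values are placeholders, not taken from this repository):

import torch

# Placeholder network and paths; the real code would rebuild the trained
# architecture before loading its state dict.
model = torch.nn.Linear(4, 2)
model_dir, name = 'models', 'example'
model.load_state_dict(torch.load(model_dir + '/model__' + name + '.pt'))
model.eval()
validation_loss = torch.load(model_dir + '/loss__' + name + '.pt')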
@@ -170,7 +152,6 @@ def evaluate_models(models, directory, num_iterations=100, colors=None,
     for iteration in range(num_iterations):
         for train_index, test_index in KFold(
                 n_splits=5, shuffle=True).split(datasets['normalized']):
-            # print("TRAIN:", train_index, "TEST:", test_index)
             for dataset in datasets.keys():
                 training_set = TensorDataset(*datasets[dataset][train_index])
                 test_set = datasets[dataset][test_index]
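
Each of the num_iterations passes reshuffles a fresh 5-fold split, so every model is evaluated on many different train/test partitions. A self-contained sketch of that split pattern on dummy data (not part of the commit):

import numpy as np
from sklearn.model_selection import KFold

data = np.arange(10)
for train_index, test_index in KFold(n_splits=5, shuffle=True).split(data):
    # Each fold holds out 2 of the 10 samples for testing.
    print('train:', train_index, 'test:', test_index)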
@@ -200,12 +181,3 @@ def evaluate_models(models, directory, num_iterations=100, colors=None,
         plt.figure(identifier)
         plt.savefig(plot_dir + '/' + identifier + '/' + '_'.join(models.keys()) + '.pdf')
-# Loss Functions: BCELoss, BCEWithLogitsLoss,
-# CrossEntropyLoss (not working), MSELoss (with reduction='sum')
-# Optimizer: Adam, SGD
-# trainer = ModelTrainer({'num_epochs': 1000})
-# trainer.epoch_training()
-# trainer.test_model()
-# trainer.save_model()
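
The removed trailer recorded which loss/optimizer combinations had been tried, plus the manual driver sequence. For reference, a hedged sketch of those choices in PyTorch (the stand-in model and the constructor argument are assumptions, not verified against the repository):

import torch

# Stand-in model so the snippet runs on its own; not from this repo.
model = torch.nn.Linear(4, 2)

# Combinations named in the removed comments; CrossEntropyLoss was noted
# as not working, and MSELoss was used with reduction='sum'.
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())  # alternative: torch.optim.SGD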
@@ -156,7 +156,7 @@ def plot_details(fine_projection, fine_mesh, coarse_projection, basis, wavelet,
 def calculate_approximate_solution(projection, points, polynomial_degree, basis):
-    """"Calculates approximate solution.
+    """Calculates approximate solution.

     Parameters
     ----------
@@ -273,7 +273,6 @@ def plot_classification_accuracy(evaluation_dict, colors):
     ax.set_ylim(top=1.02)
     ax.set_title('Classification Evaluation (Barplot)')
     ax.legend(loc='upper right')
-    # fig.tight_layout()

 def plot_boxplot(evaluation_dict, colors):