Commit 96b32ff7 authored by Laura Christine Kühle

Added option to compare models on raw and normalized data visually.

parent c9aeb432
@@ -8,7 +8,7 @@ TODO: Optimize Snakefile-vs-config relation
 TODO: Improve maximum selection runtime
 TODO: Change output to binary
 TODO: Adapt TCD file to new classification
-TODO: Add flag for evaluation of non-normalized data as well -> Next!
+TODO: Add flag for evaluation of non-normalized data as well -> Done
 TODO: Add evaluation for all classes (recall, precision, fscore)
 TODO: Add documentation
@@ -146,33 +146,42 @@ class ModelTrainer(object):
         # pass


-def read_training_data(directory):
+def read_training_data(directory, normalized=True):
     # Get training dataset from saved file and map to Torch tensor and dataset
-    input_file = directory + '/input_data.npy'
+    input_file = directory + ('/normalized_input_data.npy' if normalized else '/input_data.npy')
     output_file = directory + '/output_data.npy'
     return TensorDataset(*map(torch.tensor, (np.load(input_file), np.load(output_file))))


-def evaluate_models(models, directory, num_iterations=100, colors=None):
+def evaluate_models(models, directory, num_iterations=100, colors=None,
+                    compare_normalization=False):
     if colors is None:
         colors = {'Accuracy': 'red', 'Precision': 'yellow', 'Recall': 'blue',
                   'F-Score': 'green', 'AUROC': 'purple'}
-    dataset = read_training_data(directory)
-    classification_stats = {measure: {model: [] for model in models} for measure in colors}
+    datasets = {'normalized': read_training_data(directory)}
+    if compare_normalization:
+        datasets['raw'] = read_training_data(directory, False)
+    classification_stats = {measure: {model + ' (' + dataset + ')': [] for model in models
+                                      for dataset in datasets} for measure in colors}
     for iteration in range(num_iterations):
-        for train_index, test_index in KFold(n_splits=5, shuffle=True).split(dataset):
+        for train_index, test_index in KFold(
+                n_splits=5, shuffle=True).split(datasets['normalized']):
             # print("TRAIN:", train_index, "TEST:", test_index)
-            training_set = TensorDataset(*dataset[train_index])
-            test_set = dataset[test_index]
-            for model in models:
-                result = models[model].test_model(training_set, test_set)
-                for measure in colors:
-                    classification_stats[measure][model].append(result[measure])
+            for dataset in datasets.keys():
+                training_set = TensorDataset(*datasets[dataset][train_index])
+                test_set = datasets[dataset][test_index]
+                for model in models:
+                    result = models[model].test_model(training_set, test_set)
+                    for measure in colors:
+                        classification_stats[measure][model + ' (' + dataset + ')'].append(
+                            result[measure])
     plot_boxplot(classification_stats, colors)
-    classification_stats = {measure: {model: np.array(classification_stats[measure][model]).mean()
-                                      for model in models} for measure in colors}
+    classification_stats = {measure: {model + ' (' + dataset + ')': np.array(
+        classification_stats[measure][model + ' (' + dataset + ')']).mean() for model in models
+        for dataset in datasets} for measure in colors}
     plot_classification_accuracy(classification_stats, colors)

     # Set paths for plot files if not existing already
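The reworked evaluate_models keys every metric series by model and dataset variant ('normalized' or 'raw'), and both variants share the same KFold splits so their scores are directly comparable. A minimal runnable sketch of that bookkeeping, assuming placeholder data; DummyModel and the fixed 0.5 score stand in for the project's trained models:

    import numpy as np
    import torch
    from sklearn.model_selection import KFold
    from torch.utils.data import TensorDataset

    # Hypothetical stand-ins for the arrays read_training_data() would load.
    datasets = {'normalized': TensorDataset(torch.rand(20, 4), torch.randint(0, 2, (20,))),
                'raw': TensorDataset(torch.rand(20, 4), torch.randint(0, 2, (20,)))}
    stats = {'Accuracy': {'DummyModel (' + name + ')': [] for name in datasets}}
    # Split indices are drawn once from the normalized set and reused for every
    # variant, so all variants are scored on identical folds.
    for train_index, test_index in KFold(n_splits=5, shuffle=True).split(datasets['normalized']):
        for name in datasets:
            training_set = TensorDataset(*datasets[name][train_index])
            test_set = datasets[name][test_index]
            # A real model's test_model(training_set, test_set) would be called here.
            stats['Accuracy']['DummyModel (' + name + ')'].append(0.5)
    print({key: np.mean(values) for key, values in stats['Accuracy'].items()})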
@@ -255,10 +255,11 @@ def plot_classification_accuracy(evaluation_dict, colors):
     """
     model_names = evaluation_dict[list(colors.keys())[0]].keys()
+    font_size = 16 - (len(max(model_names, key=len))//3)
     pos = np.arange(len(model_names))
     width = 1/(3*len(model_names))
     fig = plt.figure('classification_accuracy')
-    ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
+    ax = fig.add_axes([0.15, 0.3, 0.75, 0.6])
     step_len = 1
     adjustment = -(len(model_names)//2)*step_len
     for measure in evaluation_dict:
@@ -266,7 +267,7 @@ def plot_classification_accuracy(evaluation_dict, colors):
         ax.bar(pos + adjustment*width, model_eval, width, label=measure, color=colors[measure])
         adjustment += step_len
     ax.set_xticks(pos)
-    ax.set_xticklabels(model_names)
+    ax.set_xticklabels(model_names, rotation=50, ha='right', fontsize=font_size)
     ax.set_ylabel('Classification (%)')
     ax.set_ylim(bottom=-0.02)
     ax.set_ylim(top=1.02)
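The new font_size line scales the tick labels down as the longest model name grows, which matters now that each name carries a ' (normalized)' or ' (raw)' suffix. A quick worked instance with made-up names:

    model_names = ['Adam (normalized)', 'Adam (raw)']
    font_size = 16 - (len(max(model_names, key=len)) // 3)
    # len('Adam (normalized)') == 17, so font_size == 16 - 17//3 == 11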
@@ -277,8 +278,9 @@ def plot_classification_accuracy(evaluation_dict, colors):

 def plot_boxplot(evaluation_dict, colors):
     model_names = evaluation_dict[list(colors.keys())[0]].keys()
+    font_size = 16 - (len(max(model_names, key=len))//3)
     fig = plt.figure('boxplot_accuracy')
-    ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
+    ax = fig.add_axes([0.15, 0.3, 0.75, 0.6])
     step_len = 1.5
     boxplots = []
     adjustment = -(len(model_names)//2)*step_len
@@ -294,7 +296,7 @@ def plot_boxplot(evaluation_dict, colors):
         adjustment += step_len
     ax.set_xticks(pos)
-    ax.set_xticklabels(model_names)
+    ax.set_xticklabels(model_names, rotation=50, ha='right', fontsize=font_size)
     ax.set_ylim(bottom=-0.02)
     ax.set_ylim(top=1.02)
     ax.set_ylabel('Classification (%)')
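Both plotting functions receive the same two layout fixes: a taller bottom margin (0.3 instead of 0.1) and rotated, right-aligned tick labels, so the longer combined names are neither clipped nor overlapping. A self-contained sketch of just this layout change, with placeholder names and bar heights:

    import matplotlib.pyplot as plt
    import numpy as np

    model_names = ['ModelA (normalized)', 'ModelA (raw)']
    font_size = 16 - (len(max(model_names, key=len)) // 3)
    fig = plt.figure('boxplot_accuracy')
    # Raising the bottom margin from 0.1 to 0.3 leaves room for rotated labels.
    ax = fig.add_axes([0.15, 0.3, 0.75, 0.6])
    pos = np.arange(len(model_names))
    ax.bar(pos, [0.9, 0.85], 0.4)
    ax.set_xticks(pos)
    ax.set_xticklabels(model_names, rotation=50, ha='right', fontsize=font_size)
    ax.set_ylim(bottom=-0.02, top=1.02)
    ax.set_ylabel('Classification (%)')
    plt.show()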
@@ -26,14 +26,15 @@ rule test_model:
     output:
         DIR+'/model evaluation/classification_accuracy/' + '_'.join(MODELS.keys()) + '.pdf'
     params:
-        colors = config['classification_colors']
+        colors = config['classification_colors'],
+        compare_normalization = config['compare_normalization']
     run:
         models = {}
         for model in MODELS:
             trainer = ANN_Training.ModelTrainer({'model_name': model, 'dir': DIR,
                                                  'model_dir': DIR, **MODELS[model]})
             models[model] = trainer
-        evaluate_models(models, DIR, 2, params.colors)
+        evaluate_models(models, DIR, 2, params.colors, params.compare_normalization)

 rule generate_data:
     output:
@@ -24,6 +24,7 @@ functions:
     adjustment: 0

 # Parameter for Model Training and Evaluation
+compare_normalization: True
 classification_colors:
     Accuracy: 'magenta'
     Precision: 'red'
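With the new config key in place, the Snakefile only has to forward config['compare_normalization'] into evaluate_models. A hedged sketch of the same plumbing outside Snakemake; the file name 'config.yaml' is an assumption:

    import yaml

    # Load the workflow configuration shown above.
    with open('config.yaml') as config_file:
        config = yaml.safe_load(config_file)

    colors = config['classification_colors']
    # False skips the raw-data pass entirely, reproducing the old behaviour.
    compare_normalization = config['compare_normalization']
    # evaluate_models(models, DIR, 2, colors, compare_normalization)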