From 96b32ff74c54bb1599fde9b3deb93bf16e4fe09d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=BChle=2C=20Laura=20Christine=20=28lakue103=29?=
 <laura.kuehle@uni-duesseldorf.de>
Date: Tue, 21 Dec 2021 22:18:02 +0100
Subject: [PATCH] Added option to compare models on raw and normalized data
 visually.

---
 ANN_Training.py | 39 ++++++++++++++++++++++++---------------
 Plotting.py     | 10 ++++++----
 Snakefile       |  5 +++--
 config.yaml     |  1 +
 4 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/ANN_Training.py b/ANN_Training.py
index ead8e26..2e9dd24 100644
--- a/ANN_Training.py
+++ b/ANN_Training.py
@@ -8,7 +8,7 @@ TODO: Optimize Snakefile-vs-config relation
 TODO: Improve maximum selection runtime
 TODO: Change output to binary
 TODO: Adapt TCD file to new classification
-TODO: Add flag for evaluation of non-normalized data as well -> Next!
+TODO: Add flag for evaluation of non-normalized data as well -> Done
 TODO: Add evaluation for all classes (recall, precision, fscore)
 TODO: Add documentation

@@ -146,33 +146,42 @@ class ModelTrainer(object):
         # pass


-def read_training_data(directory):
+def read_training_data(directory, normalized=True):
     # Get training dataset from saved file and map to Torch tensor and dataset
-    input_file = directory + '/input_data.npy'
+    input_file = directory + ('/normalized_input_data.npy' if normalized else '/input_data.npy')
     output_file = directory + '/output_data.npy'
     return TensorDataset(*map(torch.tensor, (np.load(input_file), np.load(output_file))))


-def evaluate_models(models, directory, num_iterations=100, colors=None):
+def evaluate_models(models, directory, num_iterations=100, colors=None,
+                    compare_normalization=False):
     if colors is None:
         colors = {'Accuracy': 'red', 'Precision': 'yellow', 'Recall': 'blue', 'F-Score': 'green', 'AUROC': 'purple'}
-    dataset = read_training_data(directory)
-    classification_stats = {measure: {model: [] for model in models} for measure in colors}
+
+    datasets = {'normalized': read_training_data(directory)}
+    if compare_normalization:
+        datasets['raw'] = read_training_data(directory, False)
+    classification_stats = {measure: {model + ' (' + dataset + ')': [] for model in models
+                                      for dataset in datasets} for measure in colors}

     for iteration in range(num_iterations):
-        for train_index, test_index in KFold(n_splits=5, shuffle=True).split(dataset):
+        for train_index, test_index in KFold(
+                n_splits=5, shuffle=True).split(datasets['normalized']):
             # print("TRAIN:", train_index, "TEST:", test_index)
-            training_set = TensorDataset(*dataset[train_index])
-            test_set = dataset[test_index]
+            for dataset in datasets.keys():
+                training_set = TensorDataset(*datasets[dataset][train_index])
+                test_set = datasets[dataset][test_index]

-            for model in models:
-                result = models[model].test_model(training_set, test_set)
-                for measure in colors:
-                    classification_stats[measure][model].append(result[measure])
+                for model in models:
+                    result = models[model].test_model(training_set, test_set)
+                    for measure in colors:
+                        classification_stats[measure][model + ' (' + dataset + ')'].append(
+                            result[measure])

     plot_boxplot(classification_stats, colors)
-    classification_stats = {measure: {model: np.array(classification_stats[measure][model]).mean()
-                            for model in models} for measure in colors}
+    classification_stats = {measure: {model + ' (' + dataset + ')': np.array(
+        classification_stats[measure][model + ' (' + dataset + ')']).mean() for model in models
+        for dataset in datasets} for measure in colors}
     plot_classification_accuracy(classification_stats, colors)

     # Set paths for plot files if not existing already
diff --git a/Plotting.py b/Plotting.py
index a50ba04..c17298b 100644
--- a/Plotting.py
+++ b/Plotting.py
@@ -255,10 +255,11 @@ def plot_classification_accuracy(evaluation_dict, colors):
     """

     model_names = evaluation_dict[list(colors.keys())[0]].keys()
+    font_size = 16 - (len(max(model_names, key=len))//3)
     pos = np.arange(len(model_names))
     width = 1/(3*len(model_names))
     fig = plt.figure('classification_accuracy')
-    ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
+    ax = fig.add_axes([0.15, 0.3, 0.75, 0.6])
     step_len = 1
     adjustment = -(len(model_names)//2)*step_len
     for measure in evaluation_dict:
@@ -266,7 +267,7 @@
         ax.bar(pos + adjustment*width, model_eval, width, label=measure, color=colors[measure])
         adjustment += step_len
     ax.set_xticks(pos)
-    ax.set_xticklabels(model_names)
+    ax.set_xticklabels(model_names, rotation=50, ha='right', fontsize=font_size)
     ax.set_ylabel('Classification (%)')
     ax.set_ylim(bottom=-0.02)
     ax.set_ylim(top=1.02)
@@ -277,8 +278,9 @@

 def plot_boxplot(evaluation_dict, colors):
     model_names = evaluation_dict[list(colors.keys())[0]].keys()
+    font_size = 16 - (len(max(model_names, key=len))//3)
     fig = plt.figure('boxplot_accuracy')
-    ax = fig.add_axes([0.15, 0.1, 0.75, 0.8])
+    ax = fig.add_axes([0.15, 0.3, 0.75, 0.6])
     step_len = 1.5
     boxplots = []
     adjustment = -(len(model_names)//2)*step_len
@@ -294,7 +296,7 @@
         adjustment += step_len

     ax.set_xticks(pos)
-    ax.set_xticklabels(model_names)
+    ax.set_xticklabels(model_names, rotation=50, ha='right', fontsize=font_size)
     ax.set_ylim(bottom=-0.02)
     ax.set_ylim(top=1.02)
     ax.set_ylabel('Classification (%)')
diff --git a/Snakefile b/Snakefile
index bd1e445..9470ab3 100644
--- a/Snakefile
+++ b/Snakefile
@@ -26,14 +26,15 @@ rule test_model:
     output:
         DIR+'/model evaluation/classification_accuracy/' + '_'.join(MODELS.keys()) + '.pdf'
     params:
-        colors = config['classification_colors']
+        colors = config['classification_colors'],
+        compare_normalization = config['compare_normalization']
     run:
         models = {}
         for model in MODELS:
             trainer= ANN_Training.ModelTrainer({'model_name': model, 'dir': DIR,
                                                 'model_dir': DIR, **MODELS[model]})
             models[model] = trainer
-        evaluate_models(models, DIR, 2, params.colors)
+        evaluate_models(models, DIR, 2, params.colors, params.compare_normalization)

 rule generate_data:
     output:
diff --git a/config.yaml b/config.yaml
index 3362a67..86171ca 100644
--- a/config.yaml
+++ b/config.yaml
@@ -24,6 +24,7 @@ functions:
   adjustment: 0

 # Parameter for Model Training and Evaluation
+compare_normalization: True
 classification_colors:
   Accuracy: 'magenta'
   Precision: 'red'
--
GitLab
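
Usage note (not part of the patch itself): below is a minimal sketch of how the new compare_normalization flag could be exercised directly in Python, mirroring what the test_model rule in the Snakefile does. The data directory name and the single-model settings dict are placeholder assumptions for illustration; only ModelTrainer and evaluate_models come from this repository.

    # Hedged usage sketch: 'data' directory contents and the 'Demo' model
    # settings are assumptions, not values taken from this repository.
    import ANN_Training
    from ANN_Training import evaluate_models

    # Assumed to contain input_data.npy, normalized_input_data.npy,
    # and output_data.npy, as read by read_training_data().
    directory = 'data'

    # The settings dict mirrors the Snakefile's pattern:
    # {'model_name': model, 'dir': DIR, 'model_dir': DIR, **MODELS[model]}
    models = {'Demo': ANN_Training.ModelTrainer(
        {'model_name': 'Demo', 'dir': directory, 'model_dir': directory})}

    # With compare_normalization=True, each model is tested on both datasets
    # in every k-fold split and appears twice in the resulting plots, as
    # 'Demo (normalized)' and 'Demo (raw)'.
    evaluate_models(models, directory, num_iterations=2,
                    compare_normalization=True)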