Skip to content
Snippets Groups Projects
Commit d1225461 authored by Laura Christine Kühle's avatar Laura Christine Kühle
Browse files

Fixed QObject error by using an non-interactive backend for matplotlib.

parent a8cd279a
Branches
No related tags found
No related merge requests found
......@@ -4,7 +4,7 @@
"""
import os
import timeit
import time
import numpy as np
import DG_Approximation
......@@ -89,13 +89,13 @@ class TrainingDataGenerator(object):
Dictionary containing input (normalized and non-normalized) and output data.
"""
tic = timeit.default_timer()
tic = time.perf_counter()
print('Calculating training data...\n')
data_dict = self._calculate_data_set(num_samples)
print('Finished calculating training data!')
self._save_data(data_dict)
toc = timeit.default_timer()
toc = time.perf_counter()
print('Total runtime:', toc-tic)
return data_dict
......
......@@ -2,17 +2,23 @@
"""
@author: Laura C. Kühle, Soraya Terrab (sorayaterrab)
Code-Style: E226, W503
Docstring-Style: D200, D400
TODO: Test new ANN set-up with Soraya
TODO: Integrate Main.py into Snakefile
TODO: Integrate Main.py into Snakefile -> Done
TODO: Generalize output for approximate_solution rule -> Done
TODO: Remove object set-up (for more flexibility)
TODO: Add documentation
TODO: Improve log output (+ switch log: to correct pos after params:)
TODO: Improve log output
TODO: Fix logs in Snakefile -> Done
TODO: Throw exception for error due to missing classes
TODO: Fix QObject timer error
TODO: Fix QObject timer error -> Done
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import pyplot as plt
import os
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
......@@ -22,6 +28,7 @@ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc
import ANN_Model
from Plotting import plot_classification_accuracy, plot_boxplot
matplotlib.use('Agg')
class ModelTrainer(object):
def __init__(self, config):
......@@ -121,7 +128,7 @@ class ModelTrainer(object):
# Saving Model
name = self._model_name
# Set paths for plot files if not existing already
# Set paths for files if not existing already
model_dir = self._dir + '/trained models'
if not os.path.exists(model_dir):
os.makedirs(model_dir)
......
......@@ -27,7 +27,8 @@ import os
import numpy as np
from sympy import Symbol
import math
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import pyplot as plt
import Troubled_Cell_Detector
import Initial_Condition
......@@ -36,6 +37,7 @@ import Quadrature
import Update_Scheme
from Basis_Function import OrthonormalLegendre
matplotlib.use('Agg')
x = Symbol('x')
......@@ -298,7 +300,7 @@ class DGScheme(object):
output_matrix[0] = output_matrix[self._num_grid_cells]
output_matrix.append(output_matrix[1])
print(np.array(output_matrix).shape)
# print(np.array(output_matrix).shape)
return np.transpose(np.array(output_matrix))
def build_training_data(self, adjustment, stencil_length, initial_condition=None):
......
......@@ -8,11 +8,13 @@ TODO: Adjust documentation for plot_classification_accuracy()
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from sympy import Symbol
matplotlib.use('Agg')
x = Symbol('x')
z = Symbol('z')
sns.set()
......
import sys
import timeit
import time
import ANN_Data_Generator, Initial_Condition, ANN_Training
from ANN_Training import evaluate_models
......@@ -37,7 +37,7 @@ rule approximate_solution:
sys.stdout = logfile
sys.stderr = logfile
tic = timeit.default_timer()
tic = time.perf_counter()
print(params.dg_params)
dg_scheme = DGScheme(plot_dir=params.plot_dir, **params.dg_params)
......@@ -45,8 +45,8 @@ rule approximate_solution:
dg_scheme.approximate()
dg_scheme.save_plots(params.plot_name)
toc = timeit.default_timer()
print('Time:',toc-tic)
toc = time.perf_counter()
print(f'Time: {toc - tic:0.4f}s')
rule test_model:
input:
......
......@@ -9,7 +9,8 @@ TODO: Load ANN state and config in reset
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from sympy import Symbol
......@@ -18,6 +19,7 @@ import ANN_Model
from Plotting import plot_solution_and_approx, plot_semilog_error, plot_error, plot_shock_tube, \
plot_details, calculate_approximate_solution, calculate_exact_solution
matplotlib.use('Agg')
x = Symbol('x')
z = Symbol('z')
......
......@@ -2,12 +2,13 @@
"""
@author: Laura C. Kühle
TODO: Discuss descriptions (matrices, cfl number, right-hand side, etc.)
TODO: Discuss descriptions (matrices, cfl number, right-hand side, limiting slope, basis, wavelet,
etc.)
TODO: Discuss referencing info on SSPRK3
"""
import numpy as np
import timeit
import time
class UpdateScheme(object):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment