diff --git a/ANN_Data_Generator.py b/ANN_Data_Generator.py index 1c80916ccda472d0cc401bde17c17911ed8ce3cb..1a799fb5e2a24957356ee01fa52deb6dd0351ace 100644 --- a/ANN_Data_Generator.py +++ b/ANN_Data_Generator.py @@ -4,7 +4,7 @@ """ import os -import timeit +import time import numpy as np import DG_Approximation @@ -89,13 +89,13 @@ class TrainingDataGenerator(object): Dictionary containing input (normalized and non-normalized) and output data. """ - tic = timeit.default_timer() + tic = time.perf_counter() print('Calculating training data...\n') data_dict = self._calculate_data_set(num_samples) print('Finished calculating training data!') self._save_data(data_dict) - toc = timeit.default_timer() + toc = time.perf_counter() print('Total runtime:', toc-tic) return data_dict diff --git a/ANN_Training.py b/ANN_Training.py index cd09c152004aef8ed59a47ed05247a8bbf9a8e13..79131b8f7d6f5e45745e2deda6e1b66fec9e73c4 100644 --- a/ANN_Training.py +++ b/ANN_Training.py @@ -2,17 +2,23 @@ """ @author: Laura C. Kühle, Soraya Terrab (sorayaterrab) +Code-Style: E226, W503 +Docstring-Style: D200, D400 + TODO: Test new ANN set-up with Soraya -TODO: Integrate Main.py into Snakefile +TODO: Integrate Main.py into Snakefile -> Done +TODO: Generalize output for approximate_solution rule -> Done TODO: Remove object set-up (for more flexibility) TODO: Add documentation -TODO: Improve log output (+ switch log: to correct pos after params:) +TODO: Improve log output +TODO: Fix logs in Snakefile -> Done TODO: Throw exception for error due to missing classes -TODO: Fix QObject timer error +TODO: Fix QObject timer error -> Done """ import numpy as np -import matplotlib.pyplot as plt +import matplotlib +from matplotlib import pyplot as plt import os import torch from torch.utils.data import TensorDataset, DataLoader, random_split @@ -22,6 +28,7 @@ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc import ANN_Model from Plotting import plot_classification_accuracy, plot_boxplot +matplotlib.use('Agg') class ModelTrainer(object): def __init__(self, config): @@ -121,7 +128,7 @@ class ModelTrainer(object): # Saving Model name = self._model_name - # Set paths for plot files if not existing already + # Set paths for files if not existing already model_dir = self._dir + '/trained models' if not os.path.exists(model_dir): os.makedirs(model_dir) diff --git a/DG_Approximation.py b/DG_Approximation.py index 6336f784691403f4bfcb22ae9bc6e8f7f305eb1b..a4114a35ba0dfacf83da02d11de44bdc3b95d776 100644 --- a/DG_Approximation.py +++ b/DG_Approximation.py @@ -27,7 +27,8 @@ import os import numpy as np from sympy import Symbol import math -import matplotlib.pyplot as plt +import matplotlib +from matplotlib import pyplot as plt import Troubled_Cell_Detector import Initial_Condition @@ -36,6 +37,7 @@ import Quadrature import Update_Scheme from Basis_Function import OrthonormalLegendre +matplotlib.use('Agg') x = Symbol('x') @@ -298,7 +300,7 @@ class DGScheme(object): output_matrix[0] = output_matrix[self._num_grid_cells] output_matrix.append(output_matrix[1]) - print(np.array(output_matrix).shape) + # print(np.array(output_matrix).shape) return np.transpose(np.array(output_matrix)) def build_training_data(self, adjustment, stencil_length, initial_condition=None): diff --git a/Plotting.py b/Plotting.py index 16019dc50e97779eb12b6f993f9ba5b91cb51c18..ff5e608617c5ebe88c11abf1662e5142a5abf2a4 100644 --- a/Plotting.py +++ b/Plotting.py @@ -8,11 +8,13 @@ TODO: Adjust documentation for plot_classification_accuracy() """ import numpy as np -import matplotlib.pyplot as plt +import matplotlib +from matplotlib import pyplot as plt import seaborn as sns from sympy import Symbol +matplotlib.use('Agg') x = Symbol('x') z = Symbol('z') sns.set() diff --git a/Snakefile b/Snakefile index ae469804a415297b3502ea1cf66baf53279718a4..a170be882a9978012bea3f2d12cc801a80780090 100644 --- a/Snakefile +++ b/Snakefile @@ -1,5 +1,5 @@ import sys -import timeit +import time import ANN_Data_Generator, Initial_Condition, ANN_Training from ANN_Training import evaluate_models @@ -37,7 +37,7 @@ rule approximate_solution: sys.stdout = logfile sys.stderr = logfile - tic = timeit.default_timer() + tic = time.perf_counter() print(params.dg_params) dg_scheme = DGScheme(plot_dir=params.plot_dir, **params.dg_params) @@ -45,8 +45,8 @@ rule approximate_solution: dg_scheme.approximate() dg_scheme.save_plots(params.plot_name) - toc = timeit.default_timer() - print('Time:',toc-tic) + toc = time.perf_counter() + print(f'Time: {toc - tic:0.4f}s') rule test_model: input: diff --git a/Troubled_Cell_Detector.py b/Troubled_Cell_Detector.py index ef91fcc84c81aea9ccc9ef2ac9173afa76d7da66..30282470b8d54c688d8cf4b5295f323239026a1c 100644 --- a/Troubled_Cell_Detector.py +++ b/Troubled_Cell_Detector.py @@ -9,7 +9,8 @@ TODO: Load ANN state and config in reset """ import numpy as np -import matplotlib.pyplot as plt +import matplotlib +from matplotlib import pyplot as plt import seaborn as sns import torch from sympy import Symbol @@ -18,6 +19,7 @@ import ANN_Model from Plotting import plot_solution_and_approx, plot_semilog_error, plot_error, plot_shock_tube, \ plot_details, calculate_approximate_solution, calculate_exact_solution +matplotlib.use('Agg') x = Symbol('x') z = Symbol('z') diff --git a/Update_Scheme.py b/Update_Scheme.py index f5dc8d8b1f2adc4f354a4f060b2a23acefd015a7..5d91109a0eb175634b76ebabb3f07c813dc077ef 100644 --- a/Update_Scheme.py +++ b/Update_Scheme.py @@ -2,12 +2,13 @@ """ @author: Laura C. Kühle -TODO: Discuss descriptions (matrices, cfl number, right-hand side, etc.) +TODO: Discuss descriptions (matrices, cfl number, right-hand side, limiting slope, basis, wavelet, + etc.) TODO: Discuss referencing info on SSPRK3 """ import numpy as np -import timeit +import time class UpdateScheme(object):