diff --git a/ANN_Data_Generator.py b/scripts/tcd/ANN_Data_Generator.py similarity index 98% rename from ANN_Data_Generator.py rename to scripts/tcd/ANN_Data_Generator.py index 995c04317a4c14a056aa33bc0e0ab1561643f8c6..07b88e21c1c2c9c5267d1b429f35f5e83e13b85b 100644 --- a/ANN_Data_Generator.py +++ b/scripts/tcd/ANN_Data_Generator.py @@ -7,10 +7,10 @@ import os import time import numpy as np -from DG_Approximation import do_initial_projection -from projection_utils import Mesh -from Quadrature import Gauss -from Basis_Function import OrthonormalLegendre +from .DG_Approximation import do_initial_projection +from .projection_utils import Mesh +from .Quadrature import Gauss +from .Basis_Function import OrthonormalLegendre class TrainingDataGenerator: diff --git a/ANN_Model.py b/scripts/tcd/ANN_Model.py similarity index 100% rename from ANN_Model.py rename to scripts/tcd/ANN_Model.py diff --git a/ANN_Training.py b/scripts/tcd/ANN_Training.py similarity index 99% rename from ANN_Training.py rename to scripts/tcd/ANN_Training.py index 2586150ee46b7dcae7d87615132559cce0013b4f..0b772ca8b1f26c4e6ff81284888f61a969191469 100644 --- a/ANN_Training.py +++ b/scripts/tcd/ANN_Training.py @@ -16,7 +16,7 @@ from sklearn.model_selection import KFold from sklearn.metrics import accuracy_score, precision_recall_fscore_support, \ roc_auc_score -import ANN_Model +from . import ANN_Model class ModelTrainer: diff --git a/Basis_Function.py b/scripts/tcd/Basis_Function.py similarity index 99% rename from Basis_Function.py rename to scripts/tcd/Basis_Function.py index 910a90a5a3ef02044dc350898c63069779efb117..758dc2a32ca782e2cdd5fbb357edbce6387d3e57 100644 --- a/Basis_Function.py +++ b/scripts/tcd/Basis_Function.py @@ -11,7 +11,7 @@ import numpy as np from numpy import ndarray from sympy import Symbol, integrate -from projection_utils import calculate_approximate_solution +from .projection_utils import calculate_approximate_solution x = Symbol('x') z = Symbol('z') diff --git a/DG_Approximation.py b/scripts/tcd/DG_Approximation.py similarity index 97% rename from DG_Approximation.py rename to scripts/tcd/DG_Approximation.py index 3e1193cefeceaedf15d6f2c731c0761d95eee66c..5c2bb512ebf86a41b990acbb8a16f35458897afd 100644 --- a/DG_Approximation.py +++ b/scripts/tcd/DG_Approximation.py @@ -22,7 +22,9 @@ TODO: Discuss referencing info on SSPRK3 TODO: Discuss name for quadrature mesh (now: grid) Urgent: -TODO: Move scripts into separate directory +TODO: Build package for DG scheme -> Done +TODO: Move scripts into separate directory -> Done +TODO: Move TODOs to Snakefile TODO: Outsource run commands in SM rules into separate files TODO: Move plot_approximation_results() into plotting script TODO: Introduce env files for each SM rule @@ -44,7 +46,6 @@ TODO: Investigate profiling for speed up Currently not critical: TODO: Add an environment file for Snakemake -TODO: Build package (module?) for DG scheme TODO: Rename files according to standard TODO: Allow comparison between ANN training datasets TODO: Use full path for ANN model state @@ -79,14 +80,14 @@ from sympy import Symbol import math import seaborn as sns -import Troubled_Cell_Detector -import Initial_Condition -import Limiter -import Quadrature -import Update_Scheme -from Basis_Function import OrthonormalLegendre -from encoding_utils import encode_ndarray -from projection_utils import Mesh +from . import Troubled_Cell_Detector +from . import Initial_Condition +from . import Limiter +from . import Quadrature +from . import Update_Scheme +from .Basis_Function import OrthonormalLegendre +from .encoding_utils import encode_ndarray +from .projection_utils import Mesh x = Symbol('x') sns.set() diff --git a/Initial_Condition.py b/scripts/tcd/Initial_Condition.py similarity index 100% rename from Initial_Condition.py rename to scripts/tcd/Initial_Condition.py diff --git a/Limiter.py b/scripts/tcd/Limiter.py similarity index 100% rename from Limiter.py rename to scripts/tcd/Limiter.py diff --git a/Plotting.py b/scripts/tcd/Plotting.py similarity index 98% rename from Plotting.py rename to scripts/tcd/Plotting.py index cab96326eb31c09ddf53e95b392342e3379c854d..43dcaa3b725da8afb6f48b8073ca4352b8646374 100644 --- a/Plotting.py +++ b/scripts/tcd/Plotting.py @@ -14,12 +14,12 @@ import seaborn as sns from numpy import ndarray from sympy import Symbol -from Quadrature import Quadrature -from Initial_Condition import InitialCondition -from Basis_Function import Basis, OrthonormalLegendre -from projection_utils import calculate_exact_solution,\ +from .Quadrature import Quadrature +from .Initial_Condition import InitialCondition +from .Basis_Function import Basis, OrthonormalLegendre +from .projection_utils import calculate_exact_solution,\ calculate_approximate_solution, Mesh -from encoding_utils import decode_ndarray +from .encoding_utils import decode_ndarray matplotlib.use('Agg') diff --git a/Quadrature.py b/scripts/tcd/Quadrature.py similarity index 100% rename from Quadrature.py rename to scripts/tcd/Quadrature.py diff --git a/Troubled_Cell_Detector.py b/scripts/tcd/Troubled_Cell_Detector.py similarity index 99% rename from Troubled_Cell_Detector.py rename to scripts/tcd/Troubled_Cell_Detector.py index cd7f98acdda727b7f956966cf8f7877bf4569b99..52f4cc2b4c475f8d1c90961fdb06e2a1391efe68 100644 --- a/Troubled_Cell_Detector.py +++ b/scripts/tcd/Troubled_Cell_Detector.py @@ -7,8 +7,8 @@ from abc import ABC, abstractmethod import numpy as np import torch -import ANN_Model -from projection_utils import Mesh +from . import ANN_Model +from .projection_utils import Mesh class TroubledCellDetector(ABC): diff --git a/Update_Scheme.py b/scripts/tcd/Update_Scheme.py similarity index 100% rename from Update_Scheme.py rename to scripts/tcd/Update_Scheme.py diff --git a/scripts/tcd/__init__.py b/scripts/tcd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/encoding_utils.py b/scripts/tcd/encoding_utils.py similarity index 100% rename from encoding_utils.py rename to scripts/tcd/encoding_utils.py diff --git a/projection_utils.py b/scripts/tcd/projection_utils.py similarity index 98% rename from projection_utils.py rename to scripts/tcd/projection_utils.py index d159bd8671f985692ec0fd92862fdebdb8ba58d3..4f4b099eaf3bba5044c219830a1f8a8bfd338ed4 100644 --- a/projection_utils.py +++ b/scripts/tcd/projection_utils.py @@ -12,8 +12,8 @@ import numpy as np from numpy import ndarray from sympy import Symbol -from Quadrature import Quadrature -from Initial_Condition import InitialCondition +from .Quadrature import Quadrature +from .Initial_Condition import InitialCondition x = Symbol('x') diff --git a/workflows/ANN_data.smk b/workflows/ANN_data.smk index dabbf1d6b7261f13aee9b062af673c859e25f31e..9d38b0173c386a75fa5d9c1c9c61fc810fe2506a 100644 --- a/workflows/ANN_data.smk +++ b/workflows/ANN_data.smk @@ -2,7 +2,8 @@ import sys import time import numpy as np -import ANN_Data_Generator, Initial_Condition +from scripts.tcd import Initial_Condition +from scripts.tcd import ANN_Data_Generator configfile: 'config.yaml' diff --git a/workflows/ANN_training.smk b/workflows/ANN_training.smk index aba34f5dfd98d2ea6270339574a2dee9e9322d88..47c3dc217ddf692de260a6895744e17df9400b8e 100644 --- a/workflows/ANN_training.smk +++ b/workflows/ANN_training.smk @@ -1,8 +1,8 @@ import sys -import ANN_Training -from ANN_Training import * -from Plotting import plot_evaluation_results +from scripts.tcd import ANN_Training +from scripts.tcd.ANN_Training import * +from scripts.tcd.Plotting import plot_evaluation_results configfile: 'config.yaml' diff --git a/workflows/approximation.smk b/workflows/approximation.smk index ab7dda7b6fe7dada42f09247f72961f1b0b7382d..ad8d8c2bffeaa459daa4fedbcea233fc76e1b478 100644 --- a/workflows/approximation.smk +++ b/workflows/approximation.smk @@ -1,10 +1,10 @@ import sys import time -import Initial_Condition -import Quadrature -from DG_Approximation import DGScheme -from Plotting import plot_approximation_results +from scripts.tcd import Initial_Condition +from scripts.tcd import Quadrature +from scripts.tcd.DG_Approximation import DGScheme +from scripts.tcd.Plotting import plot_approximation_results configfile: 'config.yaml'