# -*- coding: utf-8 -*-
"""Script to plot results of ANN model testing.

@author: Laura C. Kühle
"""
import sys
import time
import traceback
from contextlib import redirect_stderr, redirect_stdout

import seaborn as sns

from tcd import Initial_Condition
from tcd import Quadrature
from tcd.Plotting import plot_approximation_results

# Apply seaborn's default theme (global matplotlib style settings) before any
# plotting happens in this script.
# NOTE(review): in seaborn >= 0.11 `set()` is documented as an alias of
# `set_theme()`; consider switching once the minimum seaborn version is
# confirmed.
sns.set()


def main() -> None:
    """Plot results of ANN model tests.

    Meant to be executed through Snakemake's ``script:`` directive, which
    injects the global ``snakemake`` object used here. Everything written to
    stdout/stderr while the script runs -- including the traceback of an
    uncaught exception -- goes to ``snakemake.log[0]``. The original streams
    are restored before the log file is closed.

    Values read from ``snakemake``:
        wildcards['scheme']: name of the scheme whose results are plotted.
        config['schemes'][scheme]: scheme configuration; its ``'detector'``
            entry and, for the ANN detector, ``'detector_config'`` are read.
        params['dg_params']: supplies ``init_cond``/``init_config``
            (defaults ``'Sine'``/``{}``), ``quadrature``/``quadrature_config``
            (defaults ``'Gauss'``/``{}``) and ``colors`` (default ``None``).
        params['directory'], params['plot_dir']: base paths.

    Raises:
        Any exception from the plotting pipeline is logged to the log file
        and then re-raised so Snakemake marks the job as failed.
    """
    # Use context managers rather than reassigning sys.stdout/sys.stderr:
    # they restore the real streams on exit, so nothing later writes to the
    # already-closed log file (which would also swallow tracebacks).
    with open(str(snakemake.log[0]), 'w', encoding='utf-8') as logfile, \
            redirect_stdout(logfile), redirect_stderr(logfile):
        try:
            tic = time.perf_counter()

            scheme = snakemake.wildcards['scheme']
            scheme_config = snakemake.config['schemes'][scheme]

            if scheme_config['detector'] == 'ArtificialNeuralNetwork':
                # Point the detector at the trained model inside the project
                # directory.
                # NOTE(review): this mutates snakemake.params in place, but
                # 'detector_config' is not read anywhere else in this
                # script -- confirm whether it is still needed here.
                snakemake.params['dg_params']['detector_config'][
                    'model_state'] = (
                        snakemake.params['directory'] + '/trained models/'
                        + scheme_config['detector_config']['model_state'])

            # Shallow copy so the pops below do not remove keys from
            # snakemake.params itself.
            detector_dict = snakemake.params['dg_params'].copy()

            init_cond = getattr(Initial_Condition, detector_dict.pop(
                'init_cond', 'Sine'))(
                    config=detector_dict.pop('init_config', {}))
            quadrature = getattr(Quadrature, detector_dict.pop(
                'quadrature', 'Gauss'))(
                    detector_dict.pop('quadrature_config', {}))
            colors = detector_dict.pop('colors', None)

            plot_dir = snakemake.params['plot_dir']
            plot_approximation_results(
                directory=plot_dir, plot_name=scheme, colors=colors,
                data_file=plot_dir + '/' + scheme,
                quadrature=quadrature, init_cond=init_cond)

            toc = time.perf_counter()
            print(f'Time: {toc - tic:0.4f}s')
        except Exception:
            # stderr is still redirected here, so the traceback lands in the
            # rule's log; re-raise so the job is reported as failed.
            traceback.print_exc()
            raise


if __name__ == '__main__':
    # Snakemake makes a `snakemake` object available when it runs this file
    # through a `script:` directive; without it there is nothing to plot.
    if 'snakemake' not in locals():
        print('Not Defined.')
    else:
        main()