# -*- coding: utf-8 -*-
"""
@author: Soraya Terrab (sorayaterrab), Laura C. Kühle

"""
import os
import time
import numpy as np

from DG_Approximation import do_initial_projection
from projection_utils import Mesh
from Quadrature import Gauss
from Basis_Function import OrthonormalLegendre


class TrainingDataGenerator:
    """Class for training data generator.

    Generates random training data for given initial conditions.

    Attributes
    ----------
    basis_list : list
        List of OrthonormalLegendre basis instances indexed by polynomial
        degree (0 to 4). Only degrees 1 to 4 are drawn when sampling.
    quadrature_list : list
        List of Gauss quadrature instances indexed by polynomial degree,
        using ``pol_deg + 1`` nodes (i.e. 1 to 5 nodes).
    mesh_list : list
        List of Mesh instances without ghost cells for 2**3 to 2**8 cells.
    stencil_length : int
        Number of cells in each training stencil (odd, positive).

    Methods
    -------
    build_training_data(initial_conditions, num_samples, balance=0.5,
                        directory='test_data', add_reconstructions=True)
        Builds random training data and saves it.

    """
    def __init__(self, left_bound=-1, right_bound=1, stencil_length=3):
        """Initializes TrainingDataGenerator.

        Parameters
        ----------
        left_bound : float, optional
            Left boundary of interval. Default: -1.
        right_bound : float, optional
            Right boundary of interval. Default: 1.
        stencil_length : int, optional
            Number of cells per stencil; must be odd and positive so that
            a unique middle cell exists. Default: 3.

        Raises
        ------
        ValueError
            If `stencil_length` is even or not positive.

        """
        # Validate the stencil before building the bases, quadratures and
        # meshes, so invalid input fails fast.
        if stencil_length % 2 == 0:
            raise ValueError('Invalid stencil length (even value): "%d"'
                             % stencil_length)
        if stencil_length < 1:
            raise ValueError('Invalid stencil length (non-positive value): '
                             '"%d"' % stencil_length)
        self._stencil_length = stencil_length

        # Both lists are indexed directly by polynomial degree.
        self._basis_list = [OrthonormalLegendre(pol_deg)
                            for pol_deg in range(5)]
        self._quadrature_list = [Gauss({'num_nodes': pol_deg+1})
                                 for pol_deg in range(5)]
        self._mesh_list = [Mesh(left_bound=left_bound, right_bound=right_bound,
                                num_ghost_cells=0, num_grid_cells=2**exp)
                           for exp in range(3, 9)]

    def build_training_data(self, initial_conditions, num_samples, balance=0.5,
                            directory='test_data', add_reconstructions=True):
        """Builds random training data.

        Creates training data consisting of random ANN input and saves it.

        Parameters
        ----------
        initial_conditions : list
            List of dicts, each holding an initial condition object under
            'function' and the configuration passed to its `randomize`
            method under 'config'.
        num_samples : int
            Number of training data samples to generate.
        balance : float, optional
            Fraction of smooth samples, in [0, 1]; the rest are
            discontinuous. Default: 0.5.
        directory : str, optional
            Path to directory in which training data is saved.
            Default: 'test_data'.
        add_reconstructions : bool, optional
            Flag whether reconstructions of the middle cell are included.
            Default: True.

        Returns
        -------
        data_dict : dict
            Dictionary containing input (normalized and non-normalized) and
            output data.

        Raises
        ------
        ValueError
            If `balance` lies outside [0, 1], or if samples of one kind are
            requested but no initial condition of that kind is given.

        """
        tic = time.perf_counter()
        print('Calculating training data...\n')
        data_dict = self._calculate_data_set(initial_conditions,
                                             num_samples, balance,
                                             add_reconstructions)
        print('Finished calculating training data!')

        self._save_data(directory=directory, data=data_dict)
        toc = time.perf_counter()
        print(f'Total runtime: {toc - tic:0.4f}s')
        return data_dict

    def _calculate_data_set(self, initial_conditions, num_samples, balance,
                            add_reconstructions):
        """Calculates random training data of given stencil length.

        Creates training data with a given ratio between smooth and
        discontinuous samples and fixed stencil length.

        Parameters
        ----------
        initial_conditions : list
            List of dicts with keys 'function' (initial condition object)
            and 'config' (argument for its `randomize` method).
        num_samples : int
            Number of training data samples to generate.
        balance : float
            Fraction of smooth samples, in [0, 1].
        add_reconstructions : bool
            Flag whether reconstructions of the middle cell are included.

        Returns
        -------
        dict
            Dictionary with keys 'input_data.raw', 'output_data' and
            'input_data.normalized'. All entries are shuffled with the same
            permutation.

        Raises
        ------
        ValueError
            If `balance` lies outside [0, 1].

        """
        if not 0 <= balance <= 1:
            raise ValueError('Invalid balance (must lie in [0, 1]): "%s"'
                             % balance)

        # Separate smooth and discontinuous initial conditions
        smooth_functions = []
        troubled_functions = []
        for function in initial_conditions:
            if function['function'].is_smooth():
                smooth_functions.append(function)
            else:
                troubled_functions.append(function)

        num_smooth_samples = round(num_samples * balance)
        smooth_input, smooth_output = self._generate_cell_data(
            num_smooth_samples, smooth_functions, add_reconstructions, True)

        num_troubled_samples = num_samples - num_smooth_samples
        troubled_input, troubled_output = self._generate_cell_data(
            num_troubled_samples, troubled_functions, add_reconstructions,
            False)

        # Merge data
        input_matrix = np.concatenate((smooth_input, troubled_input), axis=0)
        output_matrix = np.concatenate((smooth_output, troubled_output),
                                       axis=0)

        # Shuffle data while keeping correct input/output matches
        order = np.random.permutation(len(input_matrix))
        input_matrix = input_matrix[order]
        output_matrix = output_matrix[order]

        # Create normalized input data
        norm_input_matrix = self._normalize_data(input_matrix)

        return {'input_data.raw': input_matrix, 'output_data': output_matrix,
                'input_data.normalized': norm_input_matrix}

    def _generate_cell_data(self, num_samples, initial_conditions,
                            add_reconstructions, is_smooth):
        """Generates random training input and output.

        Generates random training input and output for either smooth or
        discontinuous initial conditions. For each input the output has the
        shape [is_smooth, is_troubled].

        Parameters
        ----------
        num_samples : int
            Number of training data samples to generate.
        initial_conditions : list
            List of dicts with keys 'function' (initial condition object)
            and 'config' (argument for its `randomize` method). They are
            used cyclically.
        add_reconstructions : bool
            Flag whether reconstructions of the middle cell are included
            (adds two columns to the input).
        is_smooth : bool
            Flag whether initial conditions are smooth.

        Returns
        -------
        input_data : ndarray
            Array of shape (num_samples, stencil_length [+ 2]) containing
            input data.
        output_data : ndarray
            Array of shape (num_samples, 2) containing one-hot output data.

        Raises
        ------
        ValueError
            If `num_samples` is positive but `initial_conditions` is empty.

        """
        kind = 'smooth' if is_smooth else 'discontinuous'
        if num_samples > 0 and not initial_conditions:
            # Without this check the cyclic selection below would fail with
            # an opaque ZeroDivisionError.
            raise ValueError(f'Cannot generate {num_samples} {kind} samples: '
                             f'no {kind} initial conditions given.')

        troubled_indicator = 'without' if is_smooth else 'with'
        print('Calculating data ' + troubled_indicator + ' troubled cells...')
        print('Samples to complete:', num_samples)
        tic = time.perf_counter()

        num_datapoints = self._stencil_length
        if add_reconstructions:
            num_datapoints += 2
        input_data = np.zeros((num_samples, num_datapoints))
        num_init_cond = len(initial_conditions)
        num_meshes = len(self._mesh_list)
        for i in range(num_samples):
            # Select initial condition cyclically and re-randomize it
            entry = initial_conditions[i % num_init_cond]
            initial_condition = entry['function']
            initial_condition.randomize(entry['config'])

            # Build mesh for random stencil of given length on a randomly
            # chosen resolution
            base_mesh = self._mesh_list[np.random.randint(num_meshes)]
            mesh = base_mesh.random_stencil(self._stencil_length)

            # Induce adjustment to capture troubled cells
            adjustment = 0 if initial_condition.is_smooth() \
                else mesh.non_ghost_cells[self._stencil_length//2]
            # NOTE(review): this shift is applied to every initial condition,
            # while `adjustment` (passed to the projection) is only non-zero
            # for non-smooth ones -- confirm this asymmetry is intended.
            initial_condition.induce_adjustment(-mesh.cell_len/3)

            # Calculate basis coefficients for stencil (degree 1 to 4)
            polynomial_degree = np.random.randint(1, high=5)
            basis = self._basis_list[polynomial_degree]
            projection = do_initial_projection(
                initial_condition=initial_condition, mesh=mesh,
                basis=basis,
                quadrature=self._quadrature_list[polynomial_degree],
                adjustment=adjustment)
            # The first and last projection columns are dropped before
            # averaging (presumably boundary entries -- TODO confirm).
            input_data[i] = basis.calculate_cell_average(
                projection=projection[:, 1:-1],
                stencil_length=self._stencil_length,
                add_reconstructions=add_reconstructions)

            if (i + 1) % 1000 == 0:
                print(str(i + 1) + ' samples completed.')

        toc = time.perf_counter()
        print('Finished calculating data ' + troubled_indicator +
              ' troubled cells!')
        print(f'Calculation time: {toc - tic:0.4f}s\n')

        # Set output data: column 0 flags smooth, column 1 flags troubled
        output_data = np.zeros((num_samples, 2))
        output_data[:, int(not is_smooth)] = 1

        return input_data, output_data

    @staticmethod
    def _normalize_data(input_data):
        """Normalizes data row-wise.

        Each row is divided by its maximum absolute value. If that maximum
        is below 1, the row is divided by 1 instead, so rows are never
        scaled up.

        Parameters
        ----------
        input_data : ndarray
            2D array containing input data, one sample per row.

        Returns
        -------
        ndarray
            Array of the same shape containing normalized input data.

        """
        input_data = np.asarray(input_data)
        # keepdims lets the (n, 1) scale broadcast across each row; this
        # also keeps the (0, k) shape when there are no samples.
        scale = np.maximum(np.abs(input_data).max(axis=1, keepdims=True), 1)
        return input_data / scale

    @staticmethod
    def _save_data(directory, data):
        """Saves data.

        Stores every entry of `data` as ``<directory>/<key>.npy``, creating
        `directory` (including parents) if it does not exist.

        Parameters
        ----------
        directory : str
            Path to directory in which data is saved.
        data : dict
            Mapping of file stem to array.

        """
        os.makedirs(directory, exist_ok=True)

        print('Saving training data.')
        for key, array in data.items():
            np.save(os.path.join(directory, key + '.npy'), array)