# -*- coding: utf-8 -*-
"""
@author: Soraya Terrab (sorayaterrab), Laura C. Kühle

"""
import os
import time
import numpy as np

from DG_Approximation import do_initial_projection
from projection_utils import Mesh
from Quadrature import Gauss
from Basis_Function import OrthonormalLegendre


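# Pre-instantiate orthonormal Legendre bases and matching Gauss quadratures
# for polynomial degrees 0 through 4; they are reused for every sample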
basis_list = [OrthonormalLegendre(pol_deg) for pol_deg in range(5)]
quadrature_list = [Gauss({'num_nodes': pol_deg+1}) for pol_deg in range(5)]


class TrainingDataGenerator:
    """Class for training data generator.

    Generates random training data for given initial conditions.

    Methods
    -------
    build_training_data(initial_conditions, num_samples)
        Builds random training data.
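
    Examples
    --------
    A minimal usage sketch; ``initial_condition_list`` stands for a list of
    dicts with 'function' and 'config' entries, as expected by
    `build_training_data()`::

        generator = TrainingDataGenerator(stencil_length=3)
        data_dict = generator.build_training_data(
            initial_conditions=initial_condition_list, num_samples=1000,
            balance=0.5, directory='test_data')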

    """
    def __init__(self, left_bound=-1, right_bound=1, stencil_length=3):
        """Initializes TrainingDataGenerator.

        Parameters
        ----------
        left_bound : float, optional
            Left boundary of interval. Default: -1.
        right_bound : float, optional
            Right boundary of interval. Default: 1.
        stencil_length : int, optional
            Number of cells in the stencil; must be odd. Default: 3.

        """
        self._left_bound = left_bound
        self._right_bound = right_bound

        # Set stencil length
        if stencil_length % 2 == 0:
            raise ValueError('Invalid stencil length (even value): "%d"'
                             % stencil_length)
        self._stencil_length = stencil_length

    def build_training_data(self, initial_conditions, num_samples, balance=0.5,
                            directory='test_data', add_reconstructions=True):
        """Builds random training data.

        Creates training data consisting of random ANN input and saves it.

        Parameters
        ----------
        initial_conditions : list
            List of dicts, each holding an initial condition object
            ('function') and its randomization configuration ('config').
        num_samples : int
            Number of training data samples to generate.
        balance : float, optional
            Fraction of smooth samples in the training data. Default: 0.5.
        directory : str, optional
            Path to directory in which training data is saved.
            Default: 'test_data'.
        add_reconstructions : bool, optional
            Flag whether reconstructions of the middle cell are included.
            Default: True.

        Returns
        -------
        data_dict : dict
            Dictionary containing input (normalized and non-normalized) and
            output data.
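
        Notes
        -----
        The returned dictionary is also written to ``directory``, one '.npy'
        file per key, so the arrays can later be reloaded with, e.g.,
        ``np.load('test_data/output_data.npy')``.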

        """
        tic = time.perf_counter()
        print('Calculating training data...\n')
        data_dict = self._calculate_data_set(initial_conditions,
                                             num_samples, balance,
                                             add_reconstructions)
        print('Finished calculating training data!')

        self._save_data(directory=directory, data=data_dict)
        toc = time.perf_counter()
        print(f'Total runtime: {toc - tic:0.4f}s')
        return data_dict

    def _calculate_data_set(self, initial_conditions, num_samples, balance,
                            add_reconstructions):
        """Calculates random training data of given stencil length.

        Creates training data with a given balance between smooth and
        discontinuous samples and a fixed stencil length.

        Parameters
        ----------
        initial_conditions : list
            List of dicts, each holding an initial condition object
            ('function') and its randomization configuration ('config').
        num_samples : int
            Number of training data samples to generate.
        balance : float
            Fraction of smooth samples in the training data.
        add_reconstructions : bool
            Flag whether reconstructions of the middle cell are included.

        Returns
        -------
        dict
            Dictionary containing input (normalized and non-normalized) and
            output data.
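
        Notes
        -----
        With, e.g., ``num_samples=100`` and ``balance=0.5``, 50 samples are
        generated from smooth and 50 from discontinuous initial conditions
        before the combined set is shuffled.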

        """
        # Separate smooth and discontinuous initial conditions
        smooth_functions = []
        troubled_functions = []
        for function in initial_conditions:
            if function['function'].is_smooth():
                smooth_functions.append(function)
            else:
                troubled_functions.append(function)

        num_smooth_samples = round(num_samples * balance)
        smooth_input, smooth_output = self._generate_cell_data(
            num_smooth_samples, smooth_functions, add_reconstructions, True)

        num_troubled_samples = num_samples - num_smooth_samples
        troubled_input, troubled_output = self._generate_cell_data(
            num_troubled_samples, troubled_functions, add_reconstructions,
            False)

        # Merge Data
        input_matrix = np.concatenate((smooth_input, troubled_input), axis=0)
        output_matrix = np.concatenate((smooth_output, troubled_output),
                                       axis=0)

        # Shuffle data while keeping correct input/output matches
        order = np.random.permutation(
            num_smooth_samples + num_troubled_samples)
        input_matrix = input_matrix[order]
        output_matrix = output_matrix[order]

        # Create normalized input data
        norm_input_matrix = self._normalize_data(input_matrix)

        return {'input_data.raw': input_matrix, 'output_data': output_matrix,
                'input_data.normalized': norm_input_matrix}

    def _generate_cell_data(self, num_samples, initial_conditions,
                            add_reconstructions, is_smooth):
        """Generates random training input and output.

        Generates random training input and output for either smooth or
        discontinuous initial conditions. For each input, the corresponding
        output is a two-element vector ``[is_smooth, is_troubled]``.

        Parameters
        ----------
        num_samples : int
            Number of training data samples to generate.
        initial_conditions : list
            List of dicts, each holding an initial condition object
            ('function') and its randomization configuration ('config').
        add_reconstructions : bool
            Flag whether reconstructions of the middle cell are included.
        is_smooth : bool
            Flag whether initial conditions are smooth.

        Returns
        -------
        input_data : ndarray
            Array containing input data.
        output_data : ndarray
            Array containing output data.
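
        Notes
        -----
        A sample generated from a smooth initial condition is labeled
        ``[1, 0]``, while a sample generated from a discontinuous initial
        condition is labeled ``[0, 1]``.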

        """
        troubled_indicator = 'without' if is_smooth else 'with'
        print('Calculating data ' + troubled_indicator + ' troubled cells...')
        print('Samples to complete:', num_samples)
        tic = time.perf_counter()

        num_datapoints = self._stencil_length
        if add_reconstructions:
            num_datapoints += 2
        input_data = np.zeros((num_samples, num_datapoints))
        num_init_cond = len(initial_conditions)
        for i in range(num_samples):
            # Select and initialize initial condition
            function_id = i % num_init_cond
            initial_condition = initial_conditions[function_id]['function']
            initial_condition.randomize(
                initial_conditions[function_id]['config'])

            # Build random stencil of given length
            interval, centers, spacing = self._build_stencil()
            left_bound, right_bound = interval
            centers = [center[0] for center in centers]

            # Induce adjustment to capture troubled cells
            adjustment = 0 if initial_condition.is_smooth() \
                else centers[self._stencil_length//2]
            initial_condition.induce_adjustment(-spacing[0]/3)

            # Calculate basis coefficients for stencil using a random
            # polynomial degree between 1 and 4
            polynomial_degree = np.random.randint(1, high=5)

            mesh = Mesh(num_grid_cells=self._stencil_length, num_ghost_cells=2,
                        left_bound=left_bound, right_bound=right_bound)
            projection = do_initial_projection(
                initial_condition=initial_condition, mesh=mesh,
                basis=basis_list[polynomial_degree],
                quadrature=quadrature_list[polynomial_degree],
                adjustment=adjustment)

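            # Compute cell averages (and, if requested, reconstructions of
            # the middle cell) from the interior of the projection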
            input_data[i] = basis_list[
                polynomial_degree].calculate_cell_average(
                projection=projection[:, 1:-1],
                stencil_length=self._stencil_length,
                add_reconstructions=add_reconstructions)

            if (i + 1) % 1000 == 0:
                print(f'{i + 1} samples completed.')

        toc = time.perf_counter()
        print('Finished calculating data ' + troubled_indicator +
              ' troubled cells!')
        print(f'Calculation time: {toc - tic:0.4f}s\n')

        # Set one-hot output data (column 0: smooth, column 1: troubled)
        output_data = np.zeros((num_samples, 2))
        output_data[:, int(not is_smooth)] = np.ones(num_samples)

        return input_data, output_data

    def _build_stencil(self):
        """Builds random stencil.

        Calculates fixed number of cell centers around a random point in a
        given 1D domain.

        Returns
        -------
        interval : ndarray
            Array containing the left and right bound of the stencil interval.
        stencil : ndarray
            Array of cell centers in the stencil.
        grid_spacing : ndarray
            Array of size 1 containing the cell width (grid spacing).

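        Notes
        -----
        For example, with ``stencil_length=3``, a cell width ``h``, and a
        random point ``p``, the interval is ``[p - 1.5*h, p + 1.5*h]`` and
        the stencil consists of the centers ``[p - h, p, p + h]``.
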
        """
        # Select a random cell width h = 2/2^k with k drawn from {3, ..., 8}
        grid_spacing = 2 / (2 ** np.random.randint(3, high=9, size=1))

        # Pick random point between left and right bound
        point = np.random.uniform(self._left_bound, self._right_bound)

        # Adjust grid spacing if necessary for stencil creation
        while point - self._stencil_length/2 * grid_spacing < self._left_bound\
                or point + self._stencil_length/2 * \
                grid_spacing > self._right_bound:
            grid_spacing /= 2

        # Build x-point stencil
        interval = np.array([point - self._stencil_length/2 * grid_spacing,
                             point + self._stencil_length/2 * grid_spacing])
        stencil = np.array([point + factor * grid_spacing
                            for factor in range(-(self._stencil_length//2),
                                                self._stencil_length//2 + 1)])
        return interval, stencil, grid_spacing

    @staticmethod
    def _normalize_data(input_data):
        """Normalizes data.

        Parameters
        ----------
        input_data : ndarray
            Array containing input data.

        Returns
        -------
        ndarray
            Array containing normalized input data.
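
        Notes
        -----
        For example, the entry ``[0.5, -2.0, 1.0]`` is divided by 2 and
        becomes ``[0.25, -1.0, 0.5]``, while ``[0.2, -0.4, 0.1]`` is left
        unchanged because its maximum absolute value is below 1.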

        """
        normalized_input_data = []
        for entry in input_data:
            max_function_value = max(max(np.absolute(entry)), 1)
            normalized_input_data.append(entry / max_function_value)
        return np.array(normalized_input_data)

    @staticmethod
    def _save_data(directory, data):
        """Saves each data entry as a separate '.npy' file.

        Parameters
        ----------
        directory : str
            Path to directory in which the data is saved.
        data : dict
            Dictionary of arrays to save, keyed by file name.

        """
        # Create directory if it does not exist yet
        os.makedirs(directory, exist_ok=True)

        print('Saving training data.')
        for key in data:
            np.save(os.path.join(directory, key + '.npy'), data[key])