# -*- coding: utf-8 -*-
"""
@author: Soraya Terrab (sorayaterrab), Laura C. Kühle

"""
import os
import time
import numpy as np

import DG_Approximation


class TrainingDataGenerator:
    """Class for training data generator.

    Generates random training data for given initial conditions.

    Attributes
    ----------
    smooth_functions : list
        List of smooth initial/continuous conditions.
    troubled_functions : list
        List of discontinuous initial conditions.
    data_dir : str
        Path to directory in which training data is saved.

    Methods
    -------
    build_training_data(num_samples)
        Builds random training data.

    """
    def __init__(self, initial_conditions, left_bound=-1, right_bound=1,
                 balance=0.5, stencil_length=3, directory='test_data',
                 add_reconstructions=True):
        """Initializes TrainingDataGenerator.

        Parameters
        ----------
        initial_conditions : list
            List of dicts, each with a 'function' entry (an initial condition
            object providing ``is_smooth()``, ``randomize()`` and
            ``induce_adjustment()``) and a 'config' entry (passed to
            ``randomize()``).
        left_bound : float, optional
            Left boundary of interval. Default: -1.
        right_bound : float, optional
            Right boundary of interval. Default: 1.
        balance: float, optional
            Ratio between smooth and discontinuous training data. Default: 0.5.
        stencil_length : int, optional
            Size of training data array. Must be a positive odd number.
            Default: 3.
        directory : str, optional
            Path to directory in which training data is saved. Created if it
            does not exist. Default: 'test_data'.
        add_reconstructions: bool, optional
            Flag whether reconstructions of the middle cell are included.
            Default: True.

        Raises
        ------
        ValueError
            If `stencil_length` is even or not positive.

        """
        self._balance = balance
        self._left_bound = left_bound
        self._right_bound = right_bound
        self._add_reconstructions = add_reconstructions

        # Set stencil length; the stencil is built symmetrically around a
        # single middle cell, so the length must be a positive odd number
        if stencil_length % 2 == 0:
            raise ValueError('Invalid stencil length (even value): "%d"'
                             % stencil_length)
        if stencil_length < 1:
            raise ValueError('Invalid stencil length (non-positive value): '
                             '"%d"' % stencil_length)
        self._stencil_length = stencil_length

        # Separate smooth and discontinuous initial conditions
        self._smooth_functions = []
        self._troubled_functions = []
        for function in initial_conditions:
            if function['function'].is_smooth():
                self._smooth_functions.append(function)
            else:
                self._troubled_functions.append(function)

        # Set directory (exist_ok avoids a check-then-create race)
        self._data_dir = directory
        os.makedirs(self._data_dir, exist_ok=True)

    def build_training_data(self, num_samples):
        """Builds random training data.

        Creates training data consisting of random ANN input and saves it.

        Parameters
        ----------
        num_samples : int
            Number of training data samples to generate.

        Returns
        -------
        data_dict : dict
            Dictionary containing input (normalized and non-normalized) and
            output data.

        """
        tic = time.perf_counter()
        print('Calculating training data...\n')
        data_dict = self._calculate_data_set(num_samples)
        print('Finished calculating training data!')

        self._save_data(data_dict)
        toc = time.perf_counter()
        print(f'Total runtime: {toc - tic:0.4f}s')
        return data_dict

    def _calculate_data_set(self, num_samples):
        """Calculates random training data of given stencil length.

        Creates training data with a given ratio between smooth and
        discontinuous samples and fixed stencil length.

        Parameters
        ----------
        num_samples : int
            Number of training data samples to generate.

        Returns
        -------
        dict
            Dictionary with keys 'input_data.raw', 'output_data' and
            'input_data.normalized'.

        """
        num_smooth_samples = round(num_samples * self._balance)
        smooth_input, smooth_output = self._generate_cell_data(
            num_smooth_samples, self._smooth_functions, True)

        num_troubled_samples = num_samples - num_smooth_samples
        troubled_input, troubled_output = self._generate_cell_data(
            num_troubled_samples, self._troubled_functions, False)

        # Merge Data
        input_matrix = np.concatenate((smooth_input, troubled_input), axis=0)
        output_matrix = np.concatenate((smooth_output, troubled_output),
                                       axis=0)

        # Shuffle data while keeping correct input/output matches
        order = np.random.permutation(len(input_matrix))
        input_matrix = input_matrix[order]
        output_matrix = output_matrix[order]

        # Create normalized input data
        norm_input_matrix = self._normalize_data(input_matrix)

        return {'input_data.raw': input_matrix, 'output_data': output_matrix,
                'input_data.normalized': norm_input_matrix}

    def _generate_cell_data(self, num_samples, initial_conditions, is_smooth):
        """Generates random training input and output.

        Generates random training input and output for either smooth or
        discontinuous initial conditions. For each input the output has the
        shape [is_smooth, is_troubled].

        Parameters
        ----------
        num_samples : int
            Number of training data samples to generate.
        initial_conditions : list
            List of dicts with 'function' and 'config' entries; they are used
            in round-robin order.
        is_smooth : bool
            Flag whether initial conditions are smooth.

        Returns
        -------
        input_data : ndarray
            Array of shape (num_samples, stencil_length [+ 2 if
            reconstructions are added]) containing input data.
        output_data : ndarray
            Array of shape (num_samples, 2) containing one-hot output data.

        Raises
        ------
        ValueError
            If samples are requested but `initial_conditions` is empty.

        """
        if num_samples > 0 and not initial_conditions:
            raise ValueError('No initial conditions available to generate '
                             '%d samples.' % num_samples)

        troubled_indicator = 'without' if is_smooth else 'with'
        print('Calculating data ' + troubled_indicator + ' troubled cells...')
        print('Samples to complete:', num_samples)
        tic = time.perf_counter()

        num_datapoints = self._stencil_length
        if self._add_reconstructions:
            num_datapoints += 2
        input_data = np.zeros((num_samples, num_datapoints))
        num_init_cond = len(initial_conditions)
        for i in range(num_samples):
            # Select initial conditions round-robin and randomize the chosen one
            function_id = i % num_init_cond
            initial_condition = initial_conditions[function_id]['function']
            initial_condition.randomize(
                initial_conditions[function_id]['config'])

            # Build random stencil of given length
            interval, centers, spacing = self._build_stencil()
            left_bound, right_bound = interval
            centers = [center[0] for center in centers]

            # Induce adjustment to capture troubled cells: discontinuous
            # conditions use the center of the middle cell as adjustment.
            # is_smooth must be *called* -- the bound method object itself is
            # always truthy, which would force the adjustment to 0.
            adjustment = 0 if initial_condition.is_smooth() \
                else centers[self._stencil_length//2]
            initial_condition.induce_adjustment(-spacing[0]/3)

            # Calculate basis coefficients for stencil
            polynomial_degree = np.random.randint(1, high=5)
            dg_scheme = DG_Approximation.DGScheme(
                'NoDetection', polynomial_degree=polynomial_degree,
                num_grid_cells=self._stencil_length, left_bound=left_bound,
                right_bound=right_bound, quadrature='Gauss',
                quadrature_config={'num_eval_points': polynomial_degree+1})
            input_data[i] = dg_scheme.build_training_data(
                adjustment, self._stencil_length, self._add_reconstructions,
                initial_condition)

            completed = i + 1
            if completed % 1000 == 0:
                print(str(completed) + ' samples completed.')

        toc = time.perf_counter()
        print('Finished calculating data ' + troubled_indicator +
              ' troubled cells!')
        print(f'Calculation time: {toc - tic:0.4f}s\n')

        # Set output data: one-hot column 0 for smooth, column 1 for troubled
        output_data = np.zeros((num_samples, 2))
        output_data[:, int(not is_smooth)] = 1

        return input_data, output_data

    def _build_stencil(self):
        """Builds random stencil.

        Calculates fixed number of cell centers around a random point in a
        given 1D domain. The grid spacing is halved until the whole stencil
        fits between the left and right bound.

        Returns
        -------
        interval : ndarray
            Array containing left and right bound of the stencil interval.
        stencil : ndarray
            Array of cell centers in stencil, shape (stencil_length, 1).
        grid_spacing : ndarray
            One-element array containing the length of a cell in the grid.

        """
        # Select random cell length
        grid_spacing = 2 / (2 ** np.random.randint(3, high=9, size=1))

        # Pick random point between left and right bound
        point = np.random.uniform(self._left_bound, self._right_bound)

        # Adjust grid spacing if necessary for stencil creation
        while point - self._stencil_length/2 * grid_spacing < self._left_bound\
                or point + self._stencil_length/2 * \
                grid_spacing > self._right_bound:
            grid_spacing /= 2

        # Build x-point stencil
        interval = np.array([point - self._stencil_length/2 * grid_spacing,
                             point + self._stencil_length/2 * grid_spacing])
        stencil = np.array([point + factor * grid_spacing
                            for factor in range(-(self._stencil_length//2),
                                                self._stencil_length//2 + 1)])
        return interval, stencil, grid_spacing

    @staticmethod
    def _normalize_data(input_data):
        """Normalizes data.

        Each row is divided by its largest absolute value; rows whose largest
        absolute value is below 1 are left unscaled.

        Parameters
        ----------
        input_data : ndarray
            2D array containing input data (one sample per row).

        Returns
        -------
        ndarray
            Array containing normalized input data.

        """
        data = np.asarray(input_data)
        row_scale = np.maximum(np.max(np.absolute(data), axis=1,
                                      keepdims=True), 1)
        return data / row_scale

    def _save_data(self, data):
        """Saves each array in `data` as '<data_dir>/<key>.npy'."""
        print('Saving training data.')
        for key, value in data.items():
            name = os.path.join(self._data_dir, key + '.npy')
            np.save(name, value)