# -*- coding: utf-8 -*-
"""
@author: Soraya Terrab (sorayaterrab), Laura C. Kühle

TODO: Improve '_generate_cell_data'
TODO: Extract normalization (Combine smooth and troubled before normalizing) -> Done
TODO: Adapt code to generate both normalized and non-normalized data -> Done
TODO: Improve verbose output

"""

import numpy as np
import os

import Initial_Condition
import DG_Approximation


class TrainingDataGenerator(object):
    """Generate labelled training data for troubled-cell detection.

    Samples are drawn from a list of initial conditions. Each entry is a dict
    with the keys 'function' (an initial-condition object providing
    ``is_smooth()``, ``randomize(config)`` and ``induce_adjustment(value)``)
    and 'config' (passed to ``randomize``). Smooth and non-smooth functions
    are sampled separately and labelled with a two-column one-hot output
    (column 1: smooth, column 0: troubled).

    Parameters
    ----------
    initial_conditions : list of dict
        Initial conditions to sample from (see above).
    left_bound, right_bound : float, optional
        Domain inside which stencils are placed. Default: [-1, 1].
    balance : float, optional
        Fraction of samples drawn from smooth functions; must lie in [0, 1].
        Default: 0.5.
    stencil_length : int, optional
        Number of cells per stencil; must be a positive odd integer.
        Default: 3.
    directory : str, optional
        Directory for the saved ``.npy`` files; created if missing.
        Default: 'test_data'.

    Raises
    ------
    ValueError
        If ``balance`` lies outside [0, 1] or ``stencil_length`` is not a
        positive odd integer.
    """

    def __init__(self, initial_conditions, left_bound=-1, right_bound=1, balance=0.5,
                 stencil_length=3, directory=None):
        # Validate early: an out-of-range balance would otherwise surface much
        # later as a negative array dimension in _calculate_data_set.
        if not 0 <= balance <= 1:
            raise ValueError('Invalid balance (must be in [0, 1]): "%s"' % balance)
        self._balance = balance
        self._left_bound = left_bound
        self._right_bound = right_bound

        # Set stencil length; an odd length is required so the stencil has a
        # single centre cell (index stencil_length // 2).
        if stencil_length % 2 == 0:
            raise ValueError('Invalid stencil length (even value): "%d"' % stencil_length)
        if stencil_length < 1:
            raise ValueError('Invalid stencil length (non-positive value): "%d"'
                             % stencil_length)
        self._stencil_length = stencil_length

        # Separate smooth and discontinuous initial conditions
        self._smooth_functions = []
        self._troubled_functions = []
        for function in initial_conditions:
            if function['function'].is_smooth():
                self._smooth_functions.append(function)
            else:
                self._troubled_functions.append(function)

        # Set directory; exist_ok avoids the race between an existence check
        # and the actual creation.
        self._data_dir = 'test_data' if directory is None else directory
        os.makedirs(self._data_dir, exist_ok=True)

    def build_training_data(self, num_samples):
        """Generate ``num_samples`` samples, save them and return them.

        Returns
        -------
        dict
            Keys 'input', 'output' and 'normalized_input', each mapped to a
            NumPy array. Every entry is also saved as
            ``<directory>/<key>_data.npy``.
        """
        print('Calculating training data...')
        data_dict = self._calculate_data_set(num_samples)
        print('Finished calculating training data!')

        self._save_data(data_dict)
        return data_dict

    def _save_data(self, data):
        """Save every array in ``data`` to ``<data_dir>/<key>_data.npy``."""
        for key, array in data.items():
            np.save(os.path.join(self._data_dir, key + '_data.npy'), array)

    def _calculate_data_set(self, num_samples):
        """Build a shuffled, balanced data set of ``num_samples`` samples.

        ``round(num_samples * balance)`` samples come from smooth functions and
        the rest from troubled ones. Inputs and outputs are shuffled with the
        same permutation so that rows stay matched.

        Returns
        -------
        dict
            'input' (raw features), 'output' (one-hot labels) and
            'normalized_input' (a separate, row-normalized copy of 'input').
        """
        num_smooth_samples = round(num_samples * self._balance)
        smooth_input, smooth_output = self._generate_cell_data(num_smooth_samples,
                                                               self._smooth_functions, True)

        num_troubled_samples = num_samples - num_smooth_samples
        troubled_input, troubled_output = self._generate_cell_data(num_troubled_samples,
                                                                   self._troubled_functions, False)

        # Merge Data
        input_matrix = np.concatenate((smooth_input, troubled_input), axis=0)
        output_matrix = np.concatenate((smooth_output, troubled_output), axis=0)

        # Shuffle data while keeping correct input/output matches
        order = np.random.permutation(num_smooth_samples + num_troubled_samples)
        input_matrix = input_matrix[order]
        output_matrix = output_matrix[order]

        # Create normalized input data (a new array; input_matrix stays raw)
        norm_input_matrix = self._normalize_data(input_matrix)

        return {'input': input_matrix, 'output': output_matrix,
                'normalized_input': norm_input_matrix}

    def _generate_cell_data(self, num_samples, initial_conditions, is_smooth):
        """Generate ``num_samples`` labelled samples from ``initial_conditions``.

        The samples are split into consecutive runs of roughly equal size, one
        per initial condition; the last condition absorbs any remainder. For
        each sample the chosen condition is re-randomized with its 'config'. A
        polynomial degree is drawn from 1..4 and a random stencil is built.
        The input row is then taken from ``DGScheme.build_training_data``.

        Parameters
        ----------
        num_samples : int
            Number of samples to generate (may be 0).
        initial_conditions : list of dict
            Entries with 'function' and 'config' keys.
        is_smooth : bool
            Selects the label: column 1 is set for smooth, column 0 otherwise.

        Returns
        -------
        tuple of numpy.ndarray
            Inputs of shape (num_samples, 5) in shuffled order, and one-hot
            labels of shape (num_samples, 2).

        Raises
        ------
        ValueError
            If samples are requested but ``initial_conditions`` is empty.
        """
        # NOTE(review): the feature width 5 is hard-coded and must match the
        # length returned by DGScheme.build_training_data for the configured
        # stencil_length -- confirm against DG_Approximation.
        input_data = np.zeros((num_samples, 5))
        output_data = np.zeros((num_samples, 2))
        if num_samples == 0:
            # Nothing to generate; also avoids dividing by len([]) below when
            # a category has no functions.
            return input_data, output_data
        if not initial_conditions:
            raise ValueError('Cannot generate %d samples without any initial conditions.'
                             % num_samples)

        # At least one sample per function, otherwise the modulo below would
        # divide by zero when num_samples < len(initial_conditions).
        num_function_samples = max(num_samples // len(initial_conditions), 1)
        last_function_id = len(initial_conditions) - 1
        function_id = 0

        for i in range(num_samples):
            # Pick a Function here
            initial_condition = initial_conditions[function_id]['function']
            initial_condition.randomize(initial_conditions[function_id]['config'])

            # Create basis_coefficients for function mapped onto stencil
            polynomial_degree = np.random.randint(1, high=5)

            # Build a random stencil (interval, cell centres, grid spacing)
            interval, centers, h = self._build_stencil()
            centers = [center[0] for center in centers]

            # NOTE(review): semantics of induce_adjustment are defined by the
            # initial-condition class (not visible here); it receives -h/3.
            initial_condition.induce_adjustment(-h[0]/3)

            left_bound, right_bound = interval
            dg_scheme = DG_Approximation.DGScheme(
                'NoDetection', polynomial_degree=polynomial_degree,
                num_grid_cells=self._stencil_length, left_bound=left_bound, right_bound=right_bound,
                quadrature='Gauss', quadrature_config={'num_eval_points': polynomial_degree+1})

            # Smooth functions are evaluated with 0 as first argument, others
            # with the stencil's centre cell (reason not visible here).
            if initial_condition.is_smooth():
                input_data[i] = dg_scheme.build_training_data(
                    0, self._stencil_length, initial_condition)
            else:
                input_data[i] = dg_scheme.build_training_data(
                    centers[self._stencil_length//2], self._stencil_length, initial_condition)

            # Move on to the next function after each run of samples
            if (i % num_function_samples == num_function_samples - 1) \
                    and (function_id != last_function_id):
                function_id += 1

            if (i + 1) % 100 == 0:
                print(str(i + 1) + ' samples completed.')

        # Shuffle input data
        input_data = input_data[np.random.permutation(num_samples)]

        output_data[:, 1 if is_smooth else 0] = 1
        return input_data, output_data

    def _build_stencil(self):
        """Place a random stencil of ``stencil_length`` cells inside the domain.

        The grid spacing starts at ``2 / 2**k`` with k drawn from 3..8. It is
        halved until the stencil around a uniformly drawn point fits within
        [left_bound, right_bound].

        Returns
        -------
        tuple of numpy.ndarray
            The stencil interval (shape (2, 1)), the cell centres (shape
            (stencil_length, 1)) and the grid spacing (shape (1,)).
        """
        # Determining grid_spacing
        grid_spacing = 2 / (2 ** np.random.randint(3, high=9, size=1))

        # Pick a Random point between the left and right bound
        point = np.random.random(1) * (self._right_bound-self._left_bound) + self._left_bound

        # Ensure Bounds of x-point stencil are within the left and right bound
        half_length = self._stencil_length / 2
        while point - half_length * grid_spacing < self._left_bound \
                or point + half_length * grid_spacing > self._right_bound:
            grid_spacing = grid_spacing / 2

        # x-point stencil
        interval = np.array([point - half_length * grid_spacing,
                             point + half_length * grid_spacing])
        stencil = np.array([point + factor * grid_spacing
                            for factor in range(-(self._stencil_length//2),
                                                self._stencil_length//2 + 1)])
        return interval, stencil, grid_spacing

    @staticmethod
    def _normalize_data(input_data):
        """Return a row-normalized copy of the 2-D array ``input_data``.

        Each row is divided by ``max(max(|row|), 1)``. Rows whose largest
        absolute value is at most 1 are therefore unchanged. The argument
        itself is not modified.
        """
        scale = np.maximum(np.max(np.absolute(input_data), axis=-1, keepdims=True), 1)
        return input_data / scale


# Script section: example driver for producing training/validation data sets.
# Number of samples to request per generated data set.
sample_number = int(1e3)

# Fix the global NumPy seed so that repeated runs produce identical data.
np.random.seed(seed=1234)

# Example usage -- `functions` (a list of {'function': ..., 'config': ...}
# dicts) and `boundary` (a (left, right) pair) are not defined here and must
# be supplied by the caller:
# generator = TrainingDataGenerator(functions, left_bound=boundary[0],
#                                   right_bound=boundary[1])
# data = generator.build_training_data(sample_number)