# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

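Initial conditions evaluated on an interval that is continued periodically
outside [left_bound, right_bound].

DONE: Renamed initial conditions.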

"""
import numpy as np


class InitialCondition(object):
    """Base class for initial conditions on a periodically continued interval.

    Subclasses set ``function_name`` and implement ``_get_point``, which only
    ever receives points already mapped into ``[left_bound, right_bound]``.
    """

    def __init__(self, left_bound, right_bound, config):
        """Store the interval bounds.

        Args:
            left_bound: Left end of the interval.
            right_bound: Right end of the interval; the period used for
                wrapping is ``right_bound - left_bound``.
            config: Dictionary of subclass-specific options. Unused here;
                accepted so every subclass shares this signature.
        """
        self.left_bound = left_bound
        self.right_bound = right_bound
        self.interval_len = self.right_bound - self.left_bound

        # Placeholder name; subclasses overwrite it.
        self.function_name = 'None'

    def get_name(self):
        """Return the name of the initial condition."""
        return self.function_name

    def calculate(self, x):
        """Evaluate the initial condition at ``x``, continued periodically.

        A point left of the interval is shifted by whole periods into
        ``[left_bound, right_bound)``; a point right of it into
        ``(left_bound, right_bound]``. Points inside the interval (and NaN,
        which compares false against both bounds) are passed through
        unchanged. The shift uses the modulo operator, so it runs in constant
        time and also terminates for large ``|x|`` where adding one period
        would not change the floating-point value.

        Args:
            x: Scalar point at which to evaluate.

        Returns:
            Whatever ``_get_point`` returns for the wrapped point.

        Raises:
            ValueError: If ``x`` lies outside the interval and either the
                interval length is not positive or ``x`` is infinite, since
                no finite number of period shifts can bring it inside.
        """
        if x < self.left_bound:
            self._check_wrappable(x)
            # Python's % takes the sign of the (positive) divisor, so the
            # offset is non-negative and x lands in [left_bound, right_bound).
            x = self.left_bound + (x - self.left_bound) % self.interval_len
        elif x > self.right_bound:
            self._check_wrappable(x)
            # Measure back from the right bound so exact multiples of the
            # period land on right_bound rather than on left_bound.
            x = self.right_bound - (self.right_bound - x) % self.interval_len
        return self._get_point(x)

    def _check_wrappable(self, x):
        """Raise ValueError if ``x`` cannot be shifted into the interval."""
        # `not ... > 0` also rejects NaN interval lengths.
        if not self.interval_len > 0:
            raise ValueError(
                'Cannot wrap x={!r} into an interval of non-positive length '
                '{!r}.'.format(x, self.interval_len))
        # abs(...) == inf instead of isfinite so huge Python ints still work.
        if abs(x) == np.inf:
            raise ValueError(
                'Cannot wrap non-finite x={!r} into [{!r}, {!r}].'.format(
                    x, self.left_bound, self.right_bound))

    def _get_point(self, x):
        """Evaluate the function at a point inside the interval.

        Subclasses override this; the base implementation returns None.
        """
        return None


class Sine(InitialCondition):
    """Sine wave ``sin(factor * pi * x)``.

    Config keys (removed from the caller's ``config`` dict via ``pop``):
        factor: Multiplier applied to ``pi * x``; defaults to 2.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'Sine'
        self.factor = config.pop('factor', 2)

    def _get_point(self, x):
        """Return ``sin(factor * pi * x)`` for an already-wrapped point."""
        phase = self.factor * np.pi * x
        return np.sin(phase)


class Box(InitialCondition):
    """Box (top-hat) function: 1 on ``[-0.5, 0.5]`` and 0 elsewhere."""

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'Box'

    def _get_point(self, x):
        """Return 1 inside ``[-0.5, 0.5]`` and 0 outside.

        Before testing, ``x`` is shifted once by 2 toward ``[-1, 1]``.
        NOTE(review): this hard-codes a period of 2 independently of the
        interval bounds — confirm that is intended for non-[-1, 1] domains.
        """
        # After x + 2 with x < -1 the value is below 1, so at most one
        # of the two shifts can apply.
        if x < -1:
            x = x + 2
        elif x > 1:
            x = x - 2
        return 1 if -0.5 <= x <= 0.5 else 0


class FourPeakWave(InitialCondition):
    """Piecewise function made of four separate pulses on ``[-0.8, 0.6]``.

    From left to right:
      * ``[-0.8, -0.6]``: smoothed Gaussian ``exp(-beta (x - z)^2)`` pulse,
      * ``[-0.4, -0.2]``: square pulse of height 1,
      * ``[0, 0.2]``: triangle ``1 - |10 (x - 0.1)|``,
      * ``[0.4, 0.6]``: smoothed semi-ellipse ``sqrt(max(1 - alpha^2 (x - a)^2, 0))``,
    and 0 everywhere else. The shape parameters are fixed constants and
    are not read from ``config``. NOTE(review): this appears to be the
    classic Jiang–Shu four-wave test profile — confirm before citing it.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'FourPeakWave'

        # Fixed shape parameters; beta is derived from delta.
        self.alpha = 10
        self.delta = 0.005
        self.beta = np.log(2) / (36 * self.delta**2)
        self.a = 0.5
        self.z = -0.7

    def _get_point(self, x):
        """Evaluate the four-pulse profile at an already-wrapped point."""
        if -0.8 <= x <= -0.6:
            return self._smoothed(self._G, x, self.z)
        if -0.4 <= x <= -0.2:
            return 1
        if 0 <= x <= 0.2:
            return 1 - abs(10 * (x-0.1))
        if 0.4 <= x <= 0.6:
            return self._smoothed(self._F, x, self.a)
        return 0

    def _smoothed(self, kernel, x, centre):
        """Return ``(k(c - delta) + k(c + delta) + 4 k(c)) / 6`` at ``x``."""
        shifted_left = kernel(x, centre - self.delta)
        shifted_right = kernel(x, centre + self.delta)
        central = kernel(x, centre)
        return 1/6 * (shifted_left + shifted_right + 4 * central)

    def _G(self, x, z):
        """Gaussian kernel ``exp(-beta * (x - z)^2)``."""
        exponent = -self.beta * (x - z)**2
        return np.exp(exponent)

    def _F(self, x, a):
        """Semi-ellipse kernel ``sqrt(max(1 - alpha^2 (x - a)^2, 0))``."""
        radicand = 1 - self.alpha**2 * (x - a)**2
        return np.sqrt(max(radicand, 0))


class Linear(InitialCondition):
    """Linear function ``factor * x``.

    Config keys (removed from the caller's ``config`` dict via ``pop``):
        factor: Slope; defaults to 1.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'Linear'
        self.factor = config.pop('factor', 1)

    def _get_point(self, x):
        """Return ``factor * x`` for an already-wrapped point."""
        slope = self.factor
        return slope * x


class LinearAbsolut(InitialCondition):
    """Scaled absolute value ``factor * |x|``.

    Config keys (removed from the caller's ``config`` dict via ``pop``):
        factor: Scale applied to ``|x|``; defaults to 1.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'LinearAbsolut'
        self.factor = config.pop('factor', 1)

    def _get_point(self, x):
        """Return ``factor * |x|`` for an already-wrapped point."""
        magnitude = abs(x)
        return self.factor * magnitude


class DiscontinuousConstant(InitialCondition):
    """Piecewise-constant function with a single jump at ``x0``.

    Config keys (removed from the caller's ``config`` dict via ``pop``):
        x0: Location of the jump; defaults to 0.
        left_factor: Value for ``x <= x0``; defaults to 1.
        right_factor: Value for ``x > x0``; defaults to 0.5.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'DiscontinuousConstant'

        self.x0 = config.pop('x0', 0)
        self.left_factor = config.pop('left_factor', 1)
        self.right_factor = config.pop('right_factor', 0.5)

    def _get_point(self, x):
        """Return ``left_factor`` for ``x <= x0`` and ``right_factor`` beyond.

        The value is built by multiplying each factor with a boolean mask, so
        a NaN point (which satisfies neither comparison) yields 0.
        """
        is_left = x <= self.x0
        is_right = x > self.x0
        return self.left_factor * is_left + self.right_factor * is_right