# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""
import numpy as np


class InitialCondition(object):
    """Base class for initial conditions on a periodic domain.

    The domain is ``[left_bound, right_bound]``. Subclasses set
    ``function_name`` and implement ``_get_point``; ``calculate`` shifts a
    query point into the domain before delegating to ``_get_point``.
    """

    def __init__(self, left_bound, right_bound, config):
        """Store the domain bounds.

        Parameters
        ----------
        left_bound : float
            Left end of the domain.
        right_bound : float
            Right end of the domain.
        config : dict
            Function-specific settings. Not read by the base class;
            subclasses take their own keys from it.
        """
        self._left_bound = left_bound
        self._right_bound = right_bound
        self._interval_len = self._right_bound-self._left_bound

        self.function_name = 'None'

    def get_name(self):
        """Return the name of the initial condition."""
        return self.function_name

    def calculate(self, x):
        """Evaluate the initial condition at the scalar point ``x``.

        Points outside the domain are shifted by whole multiples of the
        domain length until they lie in it; points exactly on either bound
        are left unchanged. The shifted point is passed to ``_get_point``.

        NOTE: the shift is applied one period at a time, so the cost grows
        with the distance of ``x`` from the domain.

        Parameters
        ----------
        x : float
            Point at which to evaluate.

        Returns
        -------
        Whatever the subclass's ``_get_point`` returns for the shifted point.

        Raises
        ------
        ValueError
            If ``x`` lies outside the domain but cannot be shifted into it:
            the domain length is not positive, or a single shift no longer
            changes ``x`` (``x`` is infinite or too large for one period to
            be resolved in floating point). These cases used to loop forever.
        """
        left, right = self._left_bound, self._right_bound
        length = self._interval_len

        # A non-positive length can never move an outside point into the
        # domain (it would either stand still or move the wrong way).
        if (x < left or x > right) and not length > 0:
            raise ValueError(
                'Cannot shift x={} into [{}, {}]: domain length must be '
                'positive.'.format(x, left, right))

        while x < left:
            shifted = x + length
            # No progress means the loop would never terminate.
            if shifted == x:
                raise ValueError(
                    'Cannot shift x={} into [{}, {}]: a step of {} does not '
                    'change x.'.format(x, left, right, length))
            x = shifted
        while x > right:
            shifted = x - length
            if shifted == x:
                raise ValueError(
                    'Cannot shift x={} into [{}, {}]: a step of {} does not '
                    'change x.'.format(x, left, right, length))
            x = shifted
        return self._get_point(x)

    def _get_point(self, x):
        """Evaluate the function at a point inside the domain.

        Must be overridden by subclasses.
        """
        raise NotImplementedError(
            '{} does not implement _get_point'.format(type(self).__name__))


class Sine(InitialCondition):
    """Sine wave initial condition ``sin(factor * pi * x)``.

    Config keys (removed from ``config`` via ``pop``):
        factor -- frequency multiplier, default 2.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)

        # Consume the frequency multiplier; this mutates the caller's dict.
        self._factor = config.pop('factor', 2)

        self.function_name = 'Sine'

    def _get_point(self, x):
        phase = self._factor * np.pi * x
        return np.sin(phase)


class Box(InitialCondition):
    """Box (top-hat) initial condition: 1 on ``[-0.5, 0.5]``, 0 elsewhere.

    Before the test, ``_get_point`` shifts ``x`` once by 2 if it lies outside
    ``[-1, 1]``, i.e. it treats the box as 2-periodic.
    NOTE(review): this presumably targets the domain [-1, 1] — confirm.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'Box'

    def _get_point(self, x):
        # Single shift by the period 2 for points outside [-1, 1].
        if x < -1:
            x = x + 2
        if x > 1:
            x = x - 2
        inside = -0.5 <= x <= 0.5
        return 1 if inside else 0


class FourPeakWave(InitialCondition):
    """Four-peak wave: smoothed Gaussian, square, triangle and semi-ellipse.

    The pulses occupy [-0.8, -0.6], [-0.4, -0.2], [0, 0.2] and [0.4, 0.6];
    the function is 0 elsewhere. All shape parameters are fixed in
    ``__init__``; ``config`` is not read here.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'FourPeakWave'

        # Fixed shape parameters of the pulses.
        self._alpha = 10      # scales (x - a) inside the semi-ellipse _F
        self._delta = 0.005   # offset of the outer samples in the averages
        self._beta = np.log(2) / (36 * self._delta**2)  # Gaussian decay rate
        self._a = 0.5         # centre of the semi-ellipse
        self._z = -0.7        # centre of the Gaussian

    def _get_point(self, x):
        if -0.8 <= x <= -0.6:
            # Gaussian averaged over z - delta, z, z + delta (weights 1, 4, 1).
            outer = self._G(x, self._z-self._delta) + self._G(x, self._z+self._delta)
            return 1/6 * (outer + 4 * self._G(x, self._z))
        if -0.4 <= x <= -0.2:
            # Square pulse.
            return 1
        if 0 <= x <= 0.2:
            # Triangle peaking at x = 0.1.
            return 1 - abs(10 * (x-0.1))
        if 0.4 <= x <= 0.6:
            # Semi-ellipse averaged over a - delta, a, a + delta (weights 1, 4, 1).
            outer = self._F(x, self._a-self._delta) + self._F(x, self._a+self._delta)
            return 1/6 * (outer + 4 * self._F(x, self._a))
        return 0

    def _G(self, x, z):
        """Return the Gaussian ``exp(-beta * (x - z)**2)`` centred at ``z``."""
        squared_dist = (x-z)**2
        return np.exp(-self._beta * squared_dist)

    def _F(self, x, a):
        """Return ``sqrt(max(1 - alpha**2 * (x - a)**2, 0))`` centred at ``a``."""
        radicand = 1 - self._alpha**2 * (x-a)**2
        return np.sqrt(max(radicand, 0))


class Linear(InitialCondition):
    """Linear initial condition ``factor * x``.

    Config keys (removed from ``config`` via ``pop``):
        factor -- slope, default 1.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)

        # Consume the slope; this mutates the caller's dict.
        self._factor = config.pop('factor', 1)

        self.function_name = 'Linear'

    def _get_point(self, x):
        slope = self._factor
        return slope * x


class LinearAbsolut(InitialCondition):
    """Absolute-value initial condition ``factor * |x|``.

    Config keys (removed from ``config`` via ``pop``):
        factor -- scaling of ``|x|``, default 1.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)

        # Consume the scaling factor; this mutates the caller's dict.
        self._factor = config.pop('factor', 1)

        self.function_name = 'LinearAbsolut'

    def _get_point(self, x):
        magnitude = abs(x)
        return self._factor * magnitude


class DiscontinuousConstant(InitialCondition):
    """Piecewise constant function with a single jump at ``x0``.

    Evaluates to ``left_factor`` for ``x <= x0`` and ``right_factor`` for
    ``x > x0``.

    Config keys (each removed from ``config`` via ``pop``):
        x0 -- jump location, default 0.
        left_factor -- value at and left of the jump, default 1.
        right_factor -- value right of the jump, default 0.5.
    """

    def __init__(self, left_bound, right_bound, config):
        super().__init__(left_bound, right_bound, config)
        self.function_name = 'DiscontinuousConstant'

        self._x0 = config.pop('x0', 0)
        self._left_factor = config.pop('left_factor', 1)
        self._right_factor = config.pop('right_factor', 0.5)

    def _get_point(self, x):
        # The comparisons act as 0/1 weights, so exactly one term survives
        # for ordinary numbers (neither does for NaN, which yields 0).
        on_left = x <= self._x0
        on_right = x > self._x0
        return self._left_factor * on_left + self._right_factor * on_right