# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

d = detail coefficient (rename?)
other A (from M) = ? (Is it the same???)
A = basis_projection_left
M1 = wavelet_projection_left
phi = DG basis vector
psi = wavelet vector

TODO: Find better names for A, B, M1, and M2

"""
import numpy as np
import timeit


class UpdateScheme(object):
    """Base class for explicit update (time-stepping) schemes.

    Stores the discretization parameters, precomputes the coupling matrices
    ``A`` and ``B`` used by subclasses to assemble the right-hand side, and
    records the detected troubled cells every ``history_threshold``
    iterations. Subclasses override ``_apply_stability_method`` to implement
    the actual scheme; the base implementation does nothing.
    """

    def __init__(self, detector, limiter, init_cond, mesh, wave_speed, polynom_degree, num_grid_cells, final_time,
                 history_threshold, left_bound, right_bound):
        """Store the scheme parameters and initialize all derived state.

        Parameters
        ----------
        detector : object
            Troubled-cell detector; must provide ``get_cells(projection)``.
        limiter : object
            Limiter; must provide ``apply(projection, cell)`` returning the
            limited coefficients for column ``cell``.
        init_cond, mesh, wave_speed, final_time :
            Stored as attributes; not used by this class itself.
        polynom_degree : int
            Polynomial degree; ``A`` and ``B`` are
            ``(polynom_degree+1) x (polynom_degree+1)`` matrices.
        num_grid_cells : int
            Number of interior cells. Columns ``0`` and ``num_grid_cells+1``
            of a projection are treated as ghost cells.
        history_threshold : int
            Troubled cells and time are recorded every this many iterations.
        left_bound, right_bound : float
            Bounds of the domain interval.
        """
        # Unpack positional arguments
        self.detector = detector
        self.limiter = limiter
        self.init_cond = init_cond
        self.mesh = mesh
        self.wave_speed = wave_speed
        self.polynom_degree = polynom_degree
        self.num_grid_cells = num_grid_cells
        self.final_time = final_time
        self.history_threshold = history_threshold
        self.left_bound = left_bound
        self.right_bound = right_bound

        self._reset()

    def get_name(self):
        """Return the name of the update scheme."""
        return self.name

    def get_troubled_cell_history(self):
        """Return the list of troubled cells recorded every ``history_threshold`` iterations."""
        return self.troubled_cell_history

    def get_time_history(self):
        """Return the list of times recorded alongside the troubled-cell history."""
        return self.time_history

    def step(self, projection, cfl_number, current_time):
        """Advance the projection by one iteration of the update scheme.

        Parameters
        ----------
        projection : ndarray
            Coefficient array indexed as ``projection[:, cell]`` (one column
            per cell, ghost cells included).
        cfl_number : float
            Stored for use by the stability method of the subclass.
        current_time : float
            Time recorded in the history when the threshold is reached.

        Returns
        -------
        tuple
            ``(current_projection, troubled_cells)`` after the update.
        """
        # NOTE: both attributes alias the caller's array. A subclass must rebind
        # current_projection before mutating it in place (as _enforce_boundary_condition
        # does), otherwise the caller's array is modified as well.
        self.original_projection = projection
        self.current_projection = projection
        self.cfl_number = cfl_number
        self.time = current_time

        self._apply_stability_method()

        self.iteration += 1

        # Only record every history_threshold-th iteration to keep the history small
        if (self.iteration % self.history_threshold) == 0:
            self.troubled_cell_history.append(self.troubled_cells)
            self.time_history.append(self.time)

        return self.current_projection, self.troubled_cells

    def _apply_stability_method(self):
        """Apply the actual update; overridden by subclasses (no-op here)."""
        pass

    def _reset(self):
        """(Re)initialize derived constants, coupling matrices and per-run state."""
        # Set additional necessary fixed instance variables
        self.name = 'None'
        self.interval_len = self.right_bound-self.left_bound
        self.cell_len = self.interval_len / self.num_grid_cells

        # Row index i and column index j of every matrix entry
        indices = np.arange(self.polynom_degree+1)
        row, col = np.meshgrid(indices, indices, indexing='ij')
        # Common factor sqrt((i+1/2)(j+1/2)) of both matrices
        scaling = np.sqrt((row+0.5) * (col+0.5))

        # Set matrix A: +scaling strictly below the diagonal where i+j is odd, -scaling everywhere else
        # (formerly inv_mass @ A)
        self.A = np.where((col < row) & ((row+col) % 2 == 1), 1.0, -1.0) * scaling

        # Set matrix B: scaling with row-dependent sign (-1)^i
        # (formerly inv_mass @ B)
        self.B = np.where(row % 2 == 0, 1.0, -1.0) * scaling

        # Initialize temporary instance variables
        self.original_projection = []
        self.current_projection = []
        self.right_hand_side = []
        self.troubled_cells = []
        self.troubled_cell_history = []
        self.time_history = []
        self.cfl_number = 0
        self.time = 0
        self.iteration = 0

    def _apply_limiter(self):
        """Detect troubled cells and replace their coefficients with limited ones.

        The limiter is always evaluated on the unlimited projection; results are
        written into a copy so that limiting one cell does not affect another.
        """
        self.troubled_cells = self.detector.get_cells(self.current_projection)

        new_projection = self.current_projection.copy()
        for cell in self.troubled_cells:
            new_projection[:, cell] = self.limiter.apply(self.current_projection, cell)

        self.current_projection = new_projection

    def _enforce_boundary_condition(self):
        """Fill the ghost columns periodically (in place).

        Column 0 receives the last interior cell, column num_grid_cells+1 the first.
        """
        self.current_projection[:, 0] = self.current_projection[:, self.num_grid_cells]
        self.current_projection[:, self.num_grid_cells+1] = self.current_projection[:, 1]


class SSPRK3(UpdateScheme):
    """Three-stage strong-stability-preserving Runge-Kutta scheme (SSPRK3).

    Each stage computes a new projection from the right-hand side, then applies
    the limiter and the periodic boundary condition of the base class.
    """

    def __init__(self, detector, limiter, init_cond, mesh, wave_speed, polynom_degree, num_grid_cells, final_time,
                 history_threshold, left_bound, right_bound):
        super().__init__(detector, limiter, init_cond, mesh, wave_speed, polynom_degree, num_grid_cells, final_time,
                         history_threshold, left_bound, right_bound)

        # Set name of update scheme
        self.name = 'SSPRK3'

    # Override method of superclass
    def _apply_stability_method(self):
        """Run the three SSPRK3 stages, limiting and enforcing boundaries after each."""
        for apply_stage in (self._apply_first_step, self._apply_second_step, self._apply_third_step):
            apply_stage()
            self._apply_limiter()
            self._enforce_boundary_condition()

    def _update_right_hand_side(self):
        """Assemble the right-hand side for all cells, including both ghost cells.

        For every interior cell j (columns 1..num_grid_cells) the entry is
        2 * (A @ u_j + B @ u_{j-1}), i.e. each cell is coupled to its left
        neighbour. The ghost columns are filled periodically: column 0 with the
        last interior entry, column num_grid_cells+1 with the first.
        """
        # NOTE(review): wave_speed is not used here; coupling only to the left
        # neighbour presumably assumes a positive wave speed folded into
        # cfl_number -- confirm.
        num_cells = self.num_grid_cells
        projection = self.current_projection

        # Vectorized over all cells: column m of `interior` belongs to cell m+1
        interior = 2 * (self.A @ projection[:, 1:num_cells+1] + self.B @ projection[:, :num_cells])

        # Prepend/append the periodic ghost cells
        self.right_hand_side = np.concatenate((interior[:, -1:], interior, interior[:, :1]), axis=1)

    def _apply_first_step(self):
        """Stage 1: u1 = u_n + cfl * L(u_n)."""
        self._update_right_hand_side()
        self.current_projection = self.original_projection + (self.cfl_number*self.right_hand_side)

    def _apply_second_step(self):
        """Stage 2: u2 = 3/4 u_n + 1/4 (u1 + cfl * L(u1))."""
        self._update_right_hand_side()
        self.current_projection = 1/4 * (3 * self.original_projection
                                         + (self.current_projection + self.cfl_number*self.right_hand_side))

    def _apply_third_step(self):
        """Stage 3: u_{n+1} = 1/3 u_n + 2/3 (u2 + cfl * L(u2))."""
        self._update_right_hand_side()
        self.current_projection = 1/3 * (self.original_projection
                                         + 2 * (self.current_projection + self.cfl_number*self.right_hand_side))