# -*- coding: utf-8 -*-
"""
@author: Laura C. Kühle

"""
from __future__ import annotations

import math
from functools import cache, cached_property
from typing import Tuple

import numpy as np
from numpy import ndarray


class Mesh:
    """Class for mesh.

    Each cell is characterized by its center.

    Attributes
    ----------
    mode : str
        Mode for mesh use. Either 'training' or 'evaluation'.
    num_cells : int
        Number of cells in the mesh (ghost cells notwithstanding). Usually
        a power of 2 (required in evaluation mode).
    num_ghost_cells : int
        Number of ghost cells on both sides of the mesh, respectively.
        Either 0 during training or 1 during evaluation.
    bounds : Tuple[float, float]
        Left and right boundary of the mesh interval.
    interval_len : float
        Length of the mesh interval.
    cell_len : float
        Length of a mesh cell.
    cells : ndarray
        Array of cell centers in mesh (including ghost cells).
    non_ghost_cells : ndarray
        Array of cell centers in mesh (excluding ghost cells).

    Methods
    -------
    create_data_dict()
        Return dictionary with data necessary to construct mesh.
    random_stencil(stencil_len)
        Return random stencil.

    """

    def __init__(self, num_cells: int, left_bound: float, right_bound: float,
                 training_data_mode: bool = False) -> None:
        """Initialize Mesh.

        Parameters
        ----------
        num_cells : int
            Number of cells in the mesh (ghost cells notwithstanding). Has
            to be a power of 2 unless `training_data_mode` is set.
        left_bound : float
            Left boundary of the mesh interval.
        right_bound : float
            Right boundary of the mesh interval.
        training_data_mode : bool, optional
            Flag indicating whether the mesh is used for training data
            generation. Training meshes have no ghost cells and no
            restriction on the number of cells. Default: False.

        Raises
        ------
        ValueError
            If not in training mode and the number of cells is not a
            positive power of 2.

        """
        self._num_cells = num_cells
        self._mode = 'training' if training_data_mode else 'evaluation'
        self._num_ghost_cells = 0
        if not training_data_mode:
            self._num_ghost_cells = 1
            # Exact integer test: math.log(n, 2) is a ratio of two rounded
            # logarithms and can miss exact powers of 2 by one ulp.
            if not self._is_power_of_two(self._num_cells):
                raise ValueError('The number of cells in the mesh has to be '
                                 'an exponential of 2')
        self._left_bound = left_bound
        self._right_bound = right_bound

    @staticmethod
    def _is_power_of_two(value) -> bool:
        """Return whether value is an integral power of 2 (1, 2, 4, ...)."""
        # `not value > 0` also rejects NaN; is_integer() rejects inf and
        # fractional values before the int conversion.
        if not value > 0 or not float(value).is_integer():
            return False
        as_int = int(value)
        return as_int & (as_int - 1) == 0

    @property
    def mode(self) -> str:
        """Return mode ('training' or 'evaluation')."""
        return self._mode

    @property
    def num_cells(self) -> int:
        """Return number of mesh cells."""
        return self._num_cells

    @property
    def num_ghost_cells(self) -> int:
        """Return number of ghost mesh cells."""
        return self._num_ghost_cells

    @property
    def bounds(self) -> Tuple[float, float]:
        """Return left and right boundary of the mesh interval."""
        return self._left_bound, self._right_bound

    @property
    def interval_len(self) -> float:
        """Return the length of the mesh interval."""
        return self._right_bound - self._left_bound

    @property
    def cell_len(self) -> float:
        """Return the length of a mesh cell."""
        return self.interval_len/self.num_cells

    @cached_property
    def cells(self) -> ndarray:
        """Return the cell centers of the mesh (including ghost cells).

        The array holds exactly `num_cells + 2*num_ghost_cells` centers.
        It is computed on first access and stored on the instance. This
        uses no global cache, so meshes can still be garbage-collected.

        """
        num_total = self._num_cells + 2*self._num_ghost_cells
        first_center = (self._left_bound
                        - (self._num_ghost_cells*2-1)/2*self.cell_len)
        # Build centers from an integer index. np.arange with a float step
        # can gain or lose an element through rounding of the stop bound.
        return first_center + np.arange(num_total)*self.cell_len

    @property
    def non_ghost_cells(self) -> ndarray:
        """Return the cell centers of the mesh (excluding ghost cells)."""
        return self.cells[self._num_ghost_cells:
                          len(self.cells)-self._num_ghost_cells]

    def create_data_dict(self) -> dict:
        """Return dictionary with data necessary to construct mesh.

        The mode is not included. `Mesh(**data)` therefore builds an
        evaluation-mode mesh, which requires a power-of-2 cell count.

        """
        return {'num_cells': self._num_cells,
                'left_bound': self._left_bound,
                'right_bound': self._right_bound}

    def random_stencil(self, stencil_len: int) -> Mesh:
        """Return random stencil.

        Build mesh with given number of cell centers around a random point
        in the underlying interval.

        Parameters
        ----------
        stencil_len : int
            Number of cells in the returned stencil mesh.

        Returns
        -------
        Mesh object
            Training-mode mesh of given size around random point. Its cell
            length is this mesh's cell length, halved as often as needed
            for the stencil to fit inside this mesh's interval.

        """
        # Pick random point between left and right bound
        # NOTE(review): uniform() may return the left bound itself. In that
        # case max_spacing is 0 and the halving loop ends with zero spacing.
        point = np.random.uniform(self._left_bound, self._right_bound)

        # The stencil extends stencil_len/2 spacings to each side of the
        # point, so the spacing must not exceed 2/stencil_len * distance to
        # the nearest bound.
        mesh_spacing = self.cell_len
        max_spacing = 2/stencil_len*min(point-self._left_bound,
                                        self._right_bound-point)
        while mesh_spacing > max_spacing:
            mesh_spacing /= 2

        # Return new mesh instance
        return Mesh(left_bound=point - stencil_len/2 * mesh_spacing,
                    right_bound=point + stencil_len/2 * mesh_spacing,
                    num_cells=stencil_len, training_data_mode=True)