from tokenize import String

import imageio
import numpy as np
from scipy import ndimage
from skimage import filters

from filter_enum import FILTER
from mask import Mask


class Filter(object):
    """Apply a filter, selected by name, to a single image."""

    def __init__(self, image: imageio.core.util.Array, sigma: float = 0) -> None:
        """
        Store the image and the filter parameter for later use.

        :param image: The image to which the filter is applied.
        :param sigma: Passed to the Laplacian-of-Gaussian filter; ignored by every other filter.
        """
        self.image = image
        self.sigma = sigma

    def use(self, option: str = FILTER.empty.value) -> np.ndarray:
        """
        Apply the filter named by ``option`` to the stored image.

        ``option`` is compared against the ``value`` of the ``FILTER`` members
        ``empty``, ``roberts``, ``sobel``, ``prewitt``, ``laplacian_of_gaussian``
        and ``kirsch``.

        :param option: The name of the filter to be used. Defaults to the empty filter.
        :return: The filtered image; for the empty filter the stored image itself (not a copy).
        :raises ValueError: If ``option`` names none of the supported filters.
        """
        if option == FILTER.empty.value:
            return self.image
        if option == FILTER.roberts.value:
            return filters.roberts(self.image)
        if option == FILTER.sobel.value:
            return filters.sobel(self.image)
        if option == FILTER.prewitt.value:
            return filters.prewitt(self.image)
        if option == FILTER.laplacian_of_gaussian.value:
            return ndimage.gaussian_laplace(self.image, sigma=self.sigma)
        if option == FILTER.kirsch.value:
            return self.do_convolution_with_kirsch()

        # Fail loudly instead of silently handing ``None`` back to the caller.
        raise ValueError(f'Unknown filter option: {option!r}')

    def do_convolution_with_kirsch(self) -> np.ndarray:
        """
        Filter the stored image with the Kirsch operator.

        Every mask H_k found in ``Mask().H`` is applied to the image (element-wise
        product of the mask with the pixel neighbourhood, summed; the mask is not
        flipped), which gives one response image D_k per mask. The result holds,
        for every pixel, the maximum over all D_k and their negations -D_k, i.e.
        the largest absolute response max_k |D_k|.

        If masks of opposite directions are negatives of each other
        (H_4 = -H_0, ..., H_7 = -H_3), then D_4 = -D_0, ..., D_7 = -D_3, so it is
        enough for ``Mask().H`` to contain H_0 .. H_3 only. Supplying more masks
        is also accepted.

        The image border is handled by padding with the image mean before the
        masks are applied. The image must be two-dimensional.

        :return: Image filtered with the Kirsch operator, as ``float64`` with the shape of the input image.
        """
        masks = Mask().H  # build the mask table once, not once per pixel and mask

        # All masks are assumed to share the shape of 'K0'.
        mask_height, mask_width = masks['K0'].shape
        pad_rows, pad_cols = mask_height // 2, mask_width // 2
        height, width = self.image.shape

        # Pad every axis by its own half mask size on both sides. Padding is
        # done in the image's own dtype first and converted afterwards, so the
        # border values are exactly those np.pad produces for the input image.
        padded = np.asarray(
            np.pad(self.image, ((pad_rows, pad_rows), (pad_cols, pad_cols)), 'mean'),
            dtype=np.float64,
        )

        # Region of the padded result that corresponds to the original pixels.
        rows = slice(pad_rows, pad_rows + height)
        cols = slice(pad_cols, pad_cols + width)

        new_image = np.zeros((height, width), dtype=np.float64)
        for name in masks.keys():
            mask = np.asarray(masks[name], dtype=np.float64)
            # The boundary mode only influences the outer ring of the padded
            # image, which is cropped away below.
            response = ndimage.correlate(padded, mask, mode='constant', cval=0.0)
            # max(D_k, -D_k) == |D_k|, so the negated responses never need to
            # be stored explicitly.
            np.maximum(new_image, np.abs(response[rows, cols]), out=new_image)

        return new_image