import math
from tokenize import String
from typing import List

import cv2 as cv
import imageio
import matplotlib.pyplot as plt
import numpy as np


class ORB(object):
    """
    Detect ORB keypoints in an image, cluster them with K-Means and save a plot.

    The saved plot shows the original image, every keypoint coloured by its
    cluster label, and each cluster centre marked with a red ``+``.
    """

    def __init__(self, origin_path: str, edit_path: str, k: int = 0):
        """
        Store the input/output paths and the number of clusters.

        :param origin_path: Path of the image to read keypoints from.
        :param edit_path: Path the resulting plot is saved to.
        :param k: Number of K-Means clusters. Must be at least 1 (and at most
            the number of detected keypoints) by the time
            :meth:`get_keypoints` is called.
        """
        self.image_path = origin_path
        self.edit_path = edit_path
        self.k = k

    @staticmethod
    def euclidean(vector1: List, vector2: List) -> float:
        """
        Calculate the Euclidean distance between two vectors.

        Extra trailing components of the longer vector are ignored
        (``zip`` semantics).

        :param vector1: Vector as a sequence of numbers.
        :param vector2: Vector as a sequence of numbers.
        :return: Euclidean distance between vector1 and vector2.
        """
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(vector1, vector2)))

    def pairwise_arg_min(self, X: List, Y: List) -> np.ndarray:
        """
        Return, for every row of X, the index of the nearest row of Y.

        Distances are Euclidean (the same formula as :meth:`euclidean`) but are
        computed with one broadcasted NumPy expression instead of a Python
        double loop. Ties resolve to the lowest centroid index.

        :param X: Points, array-like of shape ``(n, d)``.
        :param Y: Centroids, array-like of shape ``(k, d)``; must be non-empty.
        :return: Integer array of shape ``(n,)`` with values in ``range(k)``.
        """
        points = np.asarray(X, dtype=float)
        centroids = np.asarray(Y, dtype=float)
        # (n, 1, d) - (1, k, d) broadcasts to all n*k pairwise differences.
        diff = points[:, np.newaxis, :] - centroids[np.newaxis, :, :]
        distances = np.sqrt((diff ** 2).sum(axis=-1))
        return np.argmin(distances, axis=1)

    def find_clusters(self, X, n_clusters: int, rseed: int = 2):
        """
        Cluster the rows of X with Lloyd's K-Means algorithm.

        The initial centres are the rows of X at ``n_clusters`` distinct indices
        chosen by a seeded permutation, so results are reproducible for a given
        ``rseed``. Iteration stops once the centres no longer change. A cluster
        that receives no points keeps its previous centre; taking the mean of
        an empty set would produce NaN, which never compares equal and would
        keep the loop from terminating.

        :param X: Points, array-like of shape ``(n, d)``.
        :param n_clusters: Number of clusters, ``1 <= n_clusters <= n``.
        :param rseed: Seed used to choose the initial centres.
        :return: ``(centers, labels)``: centres of shape ``(n_clusters, d)``
            and the cluster index of every point, shape ``(n,)``.
        :raises ValueError: If X is not 2-D or n_clusters is out of range.
        """
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(f"X must be a 2-D array of points, got shape {X.shape}")
        if not 1 <= n_clusters <= len(X):
            raise ValueError(
                f"n_clusters must be between 1 and the number of points "
                f"({len(X)}), got {n_clusters}"
            )

        # 1. Randomly choose the initial centres among the points.
        rng = np.random.RandomState(rseed)
        initial = rng.permutation(X.shape[0])[:n_clusters]
        centers = X[initial]
        while True:
            # 2a. Assign labels based on the closest centre.
            labels = self.pairwise_arg_min(X, centers)

            # 2b. Move each centre to the mean of its points; empty clusters stay put.
            new_centers = centers.astype(float)
            for j in range(n_clusters):
                members = X[labels == j]
                if len(members):
                    new_centers[j] = members.mean(axis=0)

            # 2c. Converged once no centre moves.
            if np.array_equal(centers, new_centers):
                break
            centers = new_centers

        return centers, labels

    def get_keypoints(self) -> None:
        """
        Run K-Means on the ORB keypoint positions and save an overlay plot.

        Reads the image at ``self.image_path`` and detects ORB keypoints
        (``nfeatures=1000``, FAST score). It clusters their ``(x, y)``
        positions into ``self.k`` groups and saves the image with the
        keypoints and cluster centres drawn on top to ``self.edit_path``.

        :return: None
        :raises ValueError: If no keypoints are found or ``self.k`` is out of
            range for the number of keypoints.
        """
        img = imageio.imread(uri=self.image_path)

        # Initiate the ORB detector and find the keypoints.
        orb = cv.ORB_create(nfeatures=1000, scoreType=cv.ORB_FAST_SCORE)
        keypoints = orb.detect(img, None)

        # compute() removes keypoints for which no descriptor can be computed,
        # so only keypoints with a valid descriptor are clustered.
        keypoints, _descriptors = orb.compute(img, keypoints)
        if not keypoints:
            raise ValueError(f"No ORB keypoints found in {self.image_path!r}")

        X = np.array([keypoint.pt for keypoint in keypoints], dtype=float)
        centers, labels = self.find_clusters(X, self.k)

        # Draw on a dedicated figure (not whatever figure happens to be current)
        # and always release it, even if saving fails.
        fig, ax = plt.subplots()
        try:
            ax.imshow(img)
            ax.scatter(X[:, 0], X[:, 1], marker='.', s=10, c=labels, cmap='viridis')
            ax.scatter(centers[:, 0], centers[:, 1], marker='+', color='red')
            ax.axis('off')
            fig.savefig(self.edit_path, dpi=300, bbox_inches='tight', pad_inches=0)
        finally:
            plt.close(fig)