Skip to content
Snippets Groups Projects
Select Git revision
  • 6e16334a8edc1c73a390407095a6134a9759647a
  • master default protected
  • emoUS
  • add_default_vectorizer_and_pretrained_loading
  • clean_code
  • readme
  • issue127
  • generalized_action_dicts
  • ppo_num_dialogues
  • crossowoz_ddpt
  • issue_114
  • robust_masking_feature
  • scgpt_exp
  • e2e-soloist
  • convlab_exp
  • change_system_act_in_env
  • pre-training
  • nlg-scgpt
  • remapping_actions
  • soloist
20 results

evaluate_unified_datasets.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    orb.py 2.96 KiB
    import math
    from tokenize import String
    from typing import List
    
    import cv2 as cv
    import imageio
    import matplotlib.pyplot as plt
    import numpy as np
    
    
    class ORB(object):
        """Cluster ORB keypoint locations of an image with K-Means and save a plot."""

        def __init__(self, origin_path: str, edit_path: str, k: int = 0):
            """
            This class is the editor which edits a given image and stores a new one.

            :param origin_path: Path of the original image which should be edited.
            :param edit_path: Path where the edited image should be stored.
            :param k: Number of clusters used for K-Means. It must be at least 1
                (and at most the number of detected keypoints) before
                ``get_keypoints`` is called; the default of 0 is rejected there.
            """
            self.image_path = origin_path
            self.edit_path = edit_path
            self.k = k

        @staticmethod
        def euclidean(vector1: List, vector2: List) -> float:
            """
            This method calculates the euclidean distance.

            Components are paired with ``zip``, so extra trailing components of
            the longer vector are ignored.

            :param vector1: Vector as list
            :param vector2: Vector as list
            :return: Euclidean distance between vector1 and vector2.
            """
            return math.sqrt(sum((a - b) ** 2 for a, b in zip(vector1, vector2)))

        def pairwise_arg_min(self, X: List, Y: List) -> np.ndarray:
            """
            Return, for every point in X, the index of its nearest centroid in Y.

            Ties are resolved in favour of the lowest centroid index (``np.argmin``).

            :param X: Points (feature vectors) to assign.
            :param Y: Centroids; must be non-empty.
            :return: Integer array with one centroid index per point of X.
            """
            return np.asarray([np.argmin([self.euclidean(x, y) for y in Y]) for x in X])

        def find_clusters(self, X, n_clusters: int, rseed: int = 2, max_iter: int = 300):
            """
            Cluster the rows of X with Lloyd's K-Means algorithm.

            :param X: 2-D array-like with one sample per row.
            :param n_clusters: Number of clusters; must be in ``[1, len(X)]``.
            :param rseed: Seed used to pick the initial centers from X.
            :param max_iter: Upper bound on the number of iterations, so the
                loop always terminates even if the assignment keeps changing.
            :return: Tuple ``(centers, labels)`` where ``centers`` has one row per
                cluster and ``labels[i]`` is the index of the center of ``X[i]``.
            :raises ValueError: If ``n_clusters`` is outside ``[1, len(X)]``.
            """
            X = np.asarray(X)
            n_samples = X.shape[0]
            if not 1 <= n_clusters <= n_samples:
                raise ValueError(
                    f"n_clusters must be between 1 and the number of samples "
                    f"({n_samples}), got {n_clusters}"
                )

            # 1. Randomly choose distinct samples as initial centers.
            rng = np.random.RandomState(rseed)
            centers = X[rng.permutation(n_samples)[:n_clusters]]

            for _ in range(max_iter):
                # 2a. Assign labels based on closest center.
                labels = self.pairwise_arg_min(X, centers)

                # 2b. Move each center to the mean of its points. An empty
                # cluster keeps its previous center: its mean would be NaN, and
                # NaN != NaN would make the convergence check below never pass.
                new_centers = np.array([
                    X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                    for j in range(n_clusters)
                ])

                # 2c. Check for convergence.
                if np.array_equal(centers, new_centers):
                    break
                centers = new_centers
            else:
                # max_iter reached without convergence: re-assign so the returned
                # labels refer to the returned centers.
                labels = self.pairwise_arg_min(X, centers)

            return centers, labels

        def get_keypoints(self) -> None:
            """
            Detect ORB keypoints, cluster their positions with K-Means and save
            the image overlaid with the keypoints (coloured by cluster) and the
            cluster centers (red '+') to ``self.edit_path``.

            :return: None
            :raises ValueError: If no keypoints are detected or ``self.k`` is not
                a valid cluster count for the detected keypoints.
            """
            img = imageio.imread(uri=self.image_path)

            # Initiate ORB detector
            orb = cv.ORB_create(nfeatures=1000, scoreType=cv.ORB_FAST_SCORE)

            # Find the keypoints, then run compute(); per the OpenCV docs,
            # compute() removes keypoints for which no descriptor can be
            # computed, so only the surviving keypoints are clustered.
            keypoints = orb.detect(img, None)
            keypoints, _descriptors = orb.compute(img, keypoints)
            if not keypoints:
                raise ValueError(f"No ORB keypoints detected in {self.image_path!r}")

            # (x, y) pixel coordinates, one row per keypoint.
            points = np.array([kp.pt for kp in keypoints], dtype=float)
            centers, labels = self.find_clusters(points, self.k)

            # Draw on a dedicated figure instead of pyplot's global "current"
            # figure, and always close it, even if plotting or saving fails.
            fig, ax = plt.subplots()
            try:
                ax.imshow(img)
                ax.scatter(points[:, 0], points[:, 1], marker='.', s=10, c=labels, cmap='viridis')
                ax.scatter(centers[:, 0], centers[:, 1], marker='+', color='red')
                ax.axis('off')
                fig.savefig(self.edit_path, dpi=300, bbox_inches='tight', pad_inches=0)
            finally:
                plt.close(fig)