Code source de mlstatpy.ml.kppv_laesa

import random
import numpy
from .kppv import NuagePoints


[docs] class NuagePointsLaesa(NuagePoints): """ Implémente l'algorithme des plus proches voisins, version :ref:`LAESA <space_metric_algo_laesa_prime>`. """ def __init__(self, nb_pivots): """ Construit la classe @param nb_pivots number of pivots """ NuagePoints.__init__(self) self.nb_pivots = nb_pivots
[docs] def fit(self, X, y=None): """ Follows sklearn API. @param X training set @param y labels """ self.nuage = X self.labels = y self.selection_pivots(self.nb_pivots)
[docs] def selection_pivots(self, nb): """ Sélectionne *nb* pivots aléatoirements. @param nb nombre de pivots """ nb = min(nb, self.nuage.shape[0]) if nb == 1: self.pivots = [2] else: self.pivots = set() while len(self.pivots) < nb: i = random.randint(0, self.nuage.shape[0] - 1) if i not in self.pivots: self.pivots.add(i) self.pivots = list(sorted(self.pivots)) # on calcule aussi la distance de chaque éléments au pivots self.dist = numpy.zeros((self.nuage.shape[0], len(self.pivots))) for i in range(self.nuage.shape[0]): for j in range(len(self.pivots)): self.dist[i, j] = self.distance( self.nuage[i, :], self.nuage[self.pivots[j], :] )
[docs] def ppv(self, obj): """ Retourne l'élément le plus proche de obj et sa distance avec obj, utilise la sélection à l'aide pivots @param obj object @return ``tuple(distance, index)`` """ # initialisation dp = [ (self.distance(obj, self.nuage[p, :]), p, i) for i, p in enumerate(self.pivots) ] # pivots le plus proche dm, im, _ = min(dp) # améliorations for i in range(self.nuage.shape[0]): # on regarde si un pivot permet d'éliminer l'élément i calcul = True for d, _p, ip in dp: delta = abs(d - self.dist[i, ip]) if delta > dm: calcul = False break # dans le cas contraire on calcule la distance if calcul: d = self.distance(obj, self.nuage[i, :]) if d < dm: dm = d im = i return dm, im