[docs]classNuagePointsLaesa(NuagePoints):""" Implémente l'algorithme des plus proches voisins, version :ref:`LAESA <space_metric_algo_laesa_prime>`. """def__init__(self,nb_pivots):""" Construit la classe @param nb_pivots number of pivots """NuagePoints.__init__(self)self.nb_pivots=nb_pivots
[docs]deffit(self,X,y=None):""" Follows sklearn API. @param X training set @param y labels """self.nuage=Xself.labels=yself.selection_pivots(self.nb_pivots)
[docs]defselection_pivots(self,nb):""" Sélectionne *nb* pivots aléatoirements. @param nb nombre de pivots """nb=min(nb,self.nuage.shape[0])ifnb==1:self.pivots=[2]else:self.pivots=set()whilelen(self.pivots)<nb:i=random.randint(0,self.nuage.shape[0]-1)ifinotinself.pivots:self.pivots.add(i)self.pivots=list(sorted(self.pivots))# on calcule aussi la distance de chaque éléments au pivotsself.dist=numpy.zeros((self.nuage.shape[0],len(self.pivots)))foriinrange(self.nuage.shape[0]):forjinrange(len(self.pivots)):self.dist[i,j]=self.distance(self.nuage[i,:],self.nuage[self.pivots[j],:])
[docs]defppv(self,obj):""" Retourne l'élément le plus proche de obj et sa distance avec obj, utilise la sélection à l'aide pivots @param obj object @return ``tuple(distance, index)`` """# initialisationdp=[(self.distance(obj,self.nuage[p,:]),p,i)fori,pinenumerate(self.pivots)]# pivots le plus prochedm,im,_=min(dp)# améliorationsforiinrange(self.nuage.shape[0]):# on regarde si un pivot permet d'éliminer l'élément icalcul=Trueford,_p,ipindp:delta=abs(d-self.dist[i,ip])ifdelta>dm:calcul=Falsebreak# dans le cas contraire on calcule la distanceifcalcul:d=self.distance(obj,self.nuage[i,:])ifd<dm:dm=dim=ireturndm,im