Source code for experimental_experiment.skl.helpers
importtorch
def flatnonzero(x):
    """Similar to :func:`numpy.flatnonzero`.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of any shape (including 0-d).

    Returns
    -------
    indices : torch.Tensor
        1-D tensor with the positions of the non-zero entries of ``x``
        once it is flattened to one dimension.
    """
    flat = torch.reshape(x, (-1,))
    # ``as_tuple=True`` yields one index tensor per dimension; the flattened
    # input has exactly one dimension, so unpack that single entry.
    (indices,) = torch.nonzero(flat, as_tuple=True)
    return indices
def check_non_negative(array, whom):
    """Check that ``array`` does not contain negative values.

    Parameters
    ----------
    array : torch.Tensor
        Values to validate. Anything exposing ``shape`` and ``min()``
        (e.g. a numpy array) also works.
    whom : str
        Name of the caller, used in the error message.

    Raises
    ------
    AssertionError
        If the minimum of ``array`` is negative (or NaN, because
        ``nan >= 0`` is false).
    """
    # An empty array trivially has no negative values, but ``min()`` on an
    # empty torch tensor raises instead of returning a value.
    if 0 in tuple(array.shape):
        return
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError is kept for callers that already catch it.
    # Written as ``not (min >= 0)`` so a NaN minimum is still rejected.
    if not bool(array.min() >= 0):
        raise AssertionError(f"{whom} has passed a negative value.")


def _num_samples(x):
    """Return ``len(x)``; for a tensor this is the size of its first dimension."""
    return len(x)


def _get_weights(dist, weights):
    """Get the weights from an array of distances and a parameter ``weights``.

    Assume weights have already been validated.

    Parameters
    ----------
    dist : torch.Tensor
        The input distances. The ``"distance"`` mode reduces over ``dim=1``,
        so it expects a 2-D tensor.
    weights : {'uniform', 'distance'}, callable or None
        The kind of weighting used.

    Returns
    -------
    weights_arr : torch.Tensor of the same shape as ``dist``, or None
        None when ``weights`` is None or ``'uniform'``. For ``'distance'``,
        the inverse distances; any row containing a zero distance is replaced
        by 1.0 at the zero-distance positions and 0.0 elsewhere. For a
        callable, ``weights(dist)``. Any other value returns None.
    """
    if weights in (None, "uniform"):
        return None

    if weights == "distance":
        # if user attempts to classify a point that was zero distance from one
        # or more training points, those training points are weighted as 1.0
        # and the other points as 0.0
        # ``1.0 / dist`` creates a new tensor (zeros become inf), so the
        # in-place row assignment below does not touch the caller's input.
        dist = 1.0 / dist
        inf_mask = torch.isinf(dist)
        inf_row = torch.any(inf_mask, dim=1)
        # Unlike numpy, torch's index_put requires the value dtype to match
        # the destination, so the boolean mask is cast before assignment.
        dist[inf_row] = inf_mask[inf_row].to(dist.dtype)
        return dist

    if callable(weights):
        return weights(dist)

    # Unrecognized values fall through (weights are assumed validated).
    return None


def _return_float_dtype(X, Y):
    """Cast ``X`` and ``Y`` to a common floating dtype.

    1. If dtype of X and Y is float32, then dtype float32 is returned.
    2. Else dtype float64 is returned.

    ``Y`` may be None: then only ``X`` decides the dtype and None is
    returned in place of ``Y``.

    Returns
    -------
    X, Y, dtype
        ``X`` and ``Y`` converted to ``dtype`` (``Y`` stays None if it was).
    """
    Y_dtype = X.dtype if Y is None else Y.dtype
    dtype = torch.float32 if X.dtype == Y_dtype == torch.float32 else torch.float64
    X = X.to(dtype)
    # Y is optional; calling ``.to`` on None would raise AttributeError.
    if Y is not None:
        Y = Y.to(dtype)
    return X, Y, dtype


def _check_estimator_name(estimator):
    """Return a display name for ``estimator``.

    Strings are returned unchanged, None stays None, and any other object
    is named after its class.
    """
    if estimator is None:
        return None
    if isinstance(estimator, str):
        return estimator
    return estimator.__class__.__name__