yobx.sklearn.neighbors.kneighbors_transformer
- yobx.sklearn.neighbors.kneighbors_transformer.sklearn_kneighbors_transformer(g: GraphBuilderExtendedProtocol, sts: Dict, outputs: List[str], estimator: KNeighborsTransformer, X: str, name: str = 'knn_transform') -> str
Converts a sklearn.neighbors.KNeighborsTransformer into ONNX.

The converter produces a dense (N, M) output tensor, where N is the number of query samples and M is the number of training samples:

- mode='connectivity': entry (i, j) is 1.0 when training sample j is among the n_neighbors nearest neighbours of query point i, and 0.0 otherwise.
- mode='distance': entry (i, j) is the distance from query point i to training sample j when j is among the n_neighbors nearest neighbours, and 0.0 otherwise.
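A minimal NumPy sketch of the two modes, written directly from the definitions above rather than from the converter's code. It starts from an already computed (N, M) query-to-training distance matrix, so the choice of metric is out of scope here. The function name `dense_knn_output` is hypothetical, and breaking ties in favour of the lower training index is an assumption chosen to match the tie rule of the ONNX TopK operator; this page does not specify it.

```python
import numpy as np


def dense_knn_output(dists: np.ndarray, n_neighbors: int, mode: str) -> np.ndarray:
    """Dense (N, M) k-NN matrix from an (N, M) query-to-training distance matrix."""
    if mode not in ("connectivity", "distance"):
        raise ValueError(f"unknown mode {mode!r}")
    n_queries, n_train = dists.shape
    out = np.zeros((n_queries, n_train))
    for i in range(n_queries):
        # The n_neighbors closest training samples of query i; the stable sort
        # breaks ties in favour of the lower training index.
        nearest = np.argsort(dists[i], kind="stable")[:n_neighbors]
        for j in nearest:
            # 'connectivity' marks membership, 'distance' stores the distance itself.
            out[i, j] = 1.0 if mode == "connectivity" else dists[i, j]
    return out


# Two query points against three training samples (distances given directly).
d = np.array([[0.5, 2.0, 1.0],
              [3.0, 0.2, 0.4]])
print(dense_knn_output(d, 2, "connectivity"))
print(dense_knn_output(d, 2, "distance"))
```

For query 0, training samples 0 and 2 are the two closest, so row 0 becomes [1, 0, 1] in connectivity mode and [0.5, 0, 1.0] in distance mode.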
Note
sklearn.neighbors.KNeighborsTransformer.transform() returns a sparse CSR matrix. The ONNX graph returns the equivalent dense matrix (i.e. what you would obtain by calling .toarray() on the sparse result).

Note
sklearn’s transform() uses n_neighbors + 1 neighbours internally for mode='distance' to account for self-connections when transforming the training set. This converter always uses exactly n_neighbors neighbours for both modes, so its output matches the fitted estimator's kneighbors(X, n_neighbors) applied to the query points (in dense form, estimator.kneighbors_graph(X, n_neighbors, mode=mode).toarray()). For the training set with mode='distance', one of the n_neighbors slots may therefore be taken by the query point itself (distance 0.0 scattered on the diagonal), which is indistinguishable from a non-neighbour entry.

Supported metrics: "sqeuclidean", "euclidean", "cosine", "manhattan" (aliases: "cityblock", "l1"), "chebyshev", "minkowski". The "euclidean" and "sqeuclidean" metrics use com.microsoft.CDist when that domain is registered; all other metrics use the standard-ONNX path.

Full graph structure (standard-ONNX path):
    X (N, F)
    │
    └─── pairwise distances ───► dists (N, M)
                                 │
                                 TopK(k, axis=1, largest=0) ──► values (N, k), indices (N, k)
                                                                │
    zeros (1, M) ──► Expand(N, M) ──► zeros_NM (N, M)           │
                                      │                         │
                                      ScatterElements(axis=1) ─────────► output (N, M)

- Parameters:
g – graph builder
sts – shapes defined by scikit-learn
outputs – desired output names
estimator – a fitted KNeighborsTransformer
X – input tensor name
name – prefix for node names
- Returns:
output tensor name: a dense (N, M) matrix
- Raises:
NotImplementedError – if opset < 13 or the metric is not supported
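When testing a converted model, a NumPy reference of the standard-ONNX path can be handy. The sketch below is an assumption-laden mirror of the steps described on this page, not yobx code: it applies the documented NotImplementedError checks (opset < 13, unsupported metric), computes pairwise distances for the listed metrics using their textbook definitions, then reproduces TopK, Expand and ScatterElements with np.argsort, np.broadcast_to and np.put_along_axis. The names `reference_knn_transform` and `_pairwise`, the `opset` argument, and substituting ones for the scattered values in connectivity mode are illustrative choices; the com.microsoft.CDist path is not modelled. The metric='minkowski', p=2 defaults follow KNeighborsTransformer's own defaults.

```python
import numpy as np

# Hypothetical alias table; canonical names follow the supported-metric list.
_METRIC_ALIASES = {"cityblock": "manhattan", "l1": "manhattan"}
_SUPPORTED = {"sqeuclidean", "euclidean", "cosine", "manhattan", "chebyshev", "minkowski"}


def _pairwise(X, T, metric, p):
    """(N, M) distances between rows of X (N, F) and T (M, F)."""
    diff = X[:, None, :] - T[None, :, :]  # broadcast to (N, M, F)
    if metric == "sqeuclidean":
        return (diff ** 2).sum(axis=-1)
    if metric == "euclidean":
        return np.sqrt((diff ** 2).sum(axis=-1))
    if metric == "manhattan":
        return np.abs(diff).sum(axis=-1)
    if metric == "chebyshev":
        return np.abs(diff).max(axis=-1)
    if metric == "minkowski":
        return (np.abs(diff) ** p).sum(axis=-1) ** (1.0 / p)
    # cosine: 1 - cosine similarity (rows with zero norm are not handled here)
    xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    tn = T / np.linalg.norm(T, axis=1, keepdims=True)
    return 1.0 - xn @ tn.T


def reference_knn_transform(X, T, n_neighbors, mode="distance",
                            metric="minkowski", p=2.0, opset=18):
    """NumPy reference for the standard-ONNX path (opset default is hypothetical)."""
    if opset < 13:
        raise NotImplementedError(f"opset {opset} < 13 is not supported")
    metric = _METRIC_ALIASES.get(metric, metric)
    if metric not in _SUPPORTED:
        raise NotImplementedError(f"metric {metric!r} is not supported")
    X = np.asarray(X, dtype=np.float64)
    T = np.asarray(T, dtype=np.float64)
    dists = _pairwise(X, T, metric, p)  # (N, M)
    n, m = dists.shape
    # TopK(k, axis=1, largest=0); a stable sort keeps lower indices first on ties.
    indices = np.argsort(dists, axis=1, kind="stable")[:, :n_neighbors]  # (N, k)
    values = np.take_along_axis(dists, indices, axis=1)                   # (N, k)
    if mode == "connectivity":
        values = np.ones_like(values)
    elif mode != "distance":
        raise ValueError(f"unknown mode {mode!r}")
    # zeros (1, M) -> Expand(N, M); copy because broadcast_to returns a read-only view.
    out = np.broadcast_to(np.zeros((1, m)), (n, m)).copy()
    # ScatterElements(axis=1): out[i, indices[i, j]] = values[i, j]
    np.put_along_axis(out, indices, values, axis=1)
    return out


T = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]])  # training samples
Q = np.array([[0.0, 0.0], [3.0, 3.0]])              # query points
print(reference_knn_transform(Q, T, 2, "connectivity"))
print(reference_knn_transform(Q, T, 2, metric="chebyshev"))
```

A converted model's output for the same inputs can then be compared to this reference with np.allclose; note that other implementations may resolve ties between equal distances differently.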
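The two notes above can be checked numerically without scikit-learn. The sketch below reuses a four-point training set as the query set and compares two things: a simulation of sklearn's mode='distance' transform(), which keeps n_neighbors + 1 neighbours in CSR form with the self-distance stored as an explicit zero and is then densified the way .toarray() would; and a converter-style dense output with exactly n_neighbors neighbours. All helper names are hypothetical, and the CSR construction only imitates what sklearn returns; it does not call sklearn or scipy.

```python
import numpy as np


def knn_indices(dists, k):
    # Indices of the k smallest entries per row (stable: ties -> lower index).
    return np.argsort(dists, axis=1, kind="stable")[:, :k]


def to_csr(dists, idx):
    # CSR arrays holding dists[i, idx[i, :]] for every row i; zeros stay stored.
    n, k = idx.shape
    data = np.take_along_axis(dists, idx, axis=1).ravel()
    indices = idx.ravel()
    indptr = np.arange(0, n * k + 1, k)
    return data, indices, indptr


def csr_to_dense(data, indices, indptr, shape):
    # What .toarray() does: scatter each row's stored values into a zero matrix.
    out = np.zeros(shape)
    for row in range(shape[0]):
        start, stop = indptr[row], indptr[row + 1]
        out[row, indices[start:stop]] = data[start:stop]
    return out


X = np.array([[0.0], [1.0], [3.0], [6.0]])  # training set, reused as queries
dists = np.abs(X - X.T)                     # (4, 4) distances, zero diagonal
k = 2

# Simulated sklearn transform(X), mode='distance': k + 1 neighbours incl. self.
sk_data, sk_indices, sk_indptr = to_csr(dists, knn_indices(dists, k + 1))
sk_dense = csr_to_dense(sk_data, sk_indices, sk_indptr, dists.shape)

# Converter-style output: exactly k neighbours, dense.
idx = knn_indices(dists, k)
onnx_dense = np.zeros_like(dists)
np.put_along_axis(onnx_dense, idx, np.take_along_axis(dists, idx, axis=1), axis=1)

print("stored entries per row (sparse):", np.diff(sk_indptr))
print("non-zeros per row, sklearn dense:", np.count_nonzero(sk_dense, axis=1))
print("non-zeros per row, converter:   ", np.count_nonzero(onnx_dense, axis=1))
```

Each simulated sparse row stores three entries, but densifying drops the explicit self-distance and leaves two non-zeros per row; the converter-style output keeps only one, because the query point itself occupies one of the two slots with distance 0.0.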