yobx.sklearn.neighbors.radiusneighbors_transformer

yobx.sklearn.neighbors.radiusneighbors_transformer.sklearn_radius_neighbors_transformer(g: GraphBuilderExtendedProtocol, sts: Dict, outputs: List[str], estimator: RadiusNeighborsTransformer, X: str, name: str = 'rnn_transform') → str

Converts a sklearn.neighbors.RadiusNeighborsTransformer into ONNX.

The converter produces a dense (N, M) output tensor where N is the number of query samples and M is the number of training samples.

  • mode='connectivity' — entry (i, j) is 1.0 when training sample j is within the radius of query point i, and 0.0 otherwise.

  • mode='distance' — entry (i, j) is the distance from query point i to training sample j when j is within the radius, and 0.0 otherwise.

Note

sklearn.neighbors.RadiusNeighborsTransformer.transform() returns a sparse CSR matrix. The ONNX graph returns the equivalent dense matrix (i.e. what you would obtain by calling .toarray() on the sparse result).
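
The sparse-versus-dense relationship described in this note can be checked on the scikit-learn side alone. The following is a minimal sketch with made-up toy data and an arbitrary radius; it does not run the converter, it only shows the dense array the ONNX output is expected to match.

```python
import numpy as np
from scipy import sparse
from sklearn.neighbors import RadiusNeighborsTransformer

# Toy data: the values and the radius are arbitrary choices for illustration.
X_train = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])   # M = 3 training samples
X_query = np.array([[0.5, 0.0], [2.8, 0.0]])               # N = 2 query samples

est = RadiusNeighborsTransformer(radius=1.0, mode="distance").fit(X_train)

csr = est.transform(X_query)    # sparse result returned by scikit-learn
dense = csr.toarray()           # dense (N, M) equivalent

print(sparse.issparse(csr), dense.shape)
print(dense)
```

Entries for training samples outside the radius are simply absent from the sparse matrix, so they appear as 0.0 after `.toarray()`.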

Supported metrics: "sqeuclidean", "euclidean", "cosine", "manhattan" (aliases: "cityblock", "l1"), "chebyshev", "minkowski". The "euclidean" and "sqeuclidean" metrics use com.microsoft.CDist when that domain is registered; otherwise, and for all other metrics, the converter uses the standard-ONNX path.
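
As a reference for what each metric name means, here is a NumPy sketch of the pairwise distance computation. The helper `pairwise_distances_sketch` and its `p` argument are hypothetical names introduced for illustration; this is not the converter's code and says nothing about which ONNX operators it emits.

```python
import numpy as np

def pairwise_distances_sketch(X, Y, metric="euclidean", p=2):
    """Hypothetical helper: (N, M) distances between rows of X and rows of Y."""
    diff = X[:, None, :] - Y[None, :, :]            # broadcast to shape (N, M, F)
    if metric == "sqeuclidean":
        return (diff ** 2).sum(axis=-1)
    if metric == "euclidean":
        return np.sqrt((diff ** 2).sum(axis=-1))
    if metric in ("manhattan", "cityblock", "l1"):  # three names, same metric
        return np.abs(diff).sum(axis=-1)
    if metric == "chebyshev":
        return np.abs(diff).max(axis=-1)
    if metric == "minkowski":                       # p is assumed to be the Minkowski order
        return (np.abs(diff) ** p).sum(axis=-1) ** (1.0 / p)
    if metric == "cosine":
        # Cosine distance = 1 - cosine similarity; assumes no all-zero rows.
        xn = X / np.linalg.norm(X, axis=1, keepdims=True)
        yn = Y / np.linalg.norm(Y, axis=1, keepdims=True)
        return 1.0 - xn @ yn.T
    raise NotImplementedError(f"unsupported metric: {metric!r}")

X = np.array([[1.0, 0.0], [3.0, 4.0]])
Y = np.array([[1.0, 0.0], [0.0, 1.0]])
for m in ("sqeuclidean", "euclidean", "manhattan", "chebyshev", "minkowski", "cosine"):
    print(m, pairwise_distances_sketch(X, Y, metric=m).round(3).tolist())
```

Broadcasting builds an (N, M, F) intermediate, which is fine for a reference check on small inputs but memory-hungry for large ones.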

Full graph structure (standard-ONNX path):

X (N, F)
  │
  └─── pairwise distances ─────────────────────────────────────► dists (N, M)
                                                                       │
                          in_radius = (dists <= radius) ──► mask (N, M) bool
                                                                       │
mode='connectivity':  Cast(float) ──────────────────────► output (N, M)
mode='distance':      Where(mask, dists, 0.0) ──────────► output (N, M)
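
The masking stage of the diagram above can be mirrored in NumPy. This is only a sketch of the arithmetic, assuming the distance matrix has already been computed; `radius_neighbors_dense` is a hypothetical name, and keeping the dtype of `dists` for the cast is an assumption rather than a statement about the converter.

```python
import numpy as np

def radius_neighbors_dense(dists, radius, mode="distance"):
    """Hypothetical helper mirroring the mask / Cast / Where stages."""
    mask = dists <= radius                      # in_radius, boolean (N, M)
    if mode == "connectivity":
        return mask.astype(dists.dtype)         # Cast: True -> 1.0, False -> 0.0
    if mode == "distance":
        return np.where(mask, dists, 0.0)       # keep distances inside the radius
    raise ValueError(f"unknown mode: {mode!r}")

# Precomputed toy distances for N = 2 queries and M = 3 training samples.
dists = np.array([[0.5, 0.5, 2.5],
                  [2.8, 1.8, 0.2]], dtype=np.float32)

connectivity = radius_neighbors_dense(dists, radius=1.0, mode="connectivity")
distance = radius_neighbors_dense(dists, radius=1.0, mode="distance")
print(connectivity)
print(distance)
```

Note that in mode='distance' a neighbor at distance exactly 0.0 is indistinguishable from a non-neighbor in the dense output, since both are stored as 0.0.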

Parameters:
  • g – graph builder

  • sts – shapes defined by scikit-learn

  • outputs – desired output names

  • estimator – a fitted RadiusNeighborsTransformer

  • X – input tensor name

  • name – prefix for node names

Returns:

output tensor name — dense (N, M) matrix

Raises:

NotImplementedError – if opset < 13 or the metric is not supported