yobx.sklearn.neighbors.nearest_centroid
- yobx.sklearn.neighbors.nearest_centroid.sklearn_nearest_centroid(g: GraphBuilderExtendedProtocol, sts: Dict, outputs: List[str], estimator: NearestCentroid, X: str, name: str = 'nearest_centroid') → str | Tuple[str, str]
Converts a sklearn.neighbors.NearestCentroid into ONNX. The converter is registered only if scikit-learn>=1.8; it is not tested with earlier versions. It reproduces both predict() and predict_proba().

Uniform-prior labels path (all class_prior_ values are equal):

sklearn assigns each sample to the nearest centroid using raw pairwise distances (no feature normalisation):

    X (N, F)
      │
      └── pairwise distances ──────────► dists (N, C)
            │
            └── ArgMin(axis=1) ────────► idx (N,)
                  │
                  └── Gather(classes_) ──► label (N,)

Non-uniform-prior labels path:
The discriminant score for class c (eq. 18.2 of The Elements of Statistical Learning, 2nd ed.) is:

    score_c = -dist(X_norm, centroid_norm_c)² + 2 * log(prior_c)
where X_norm and centroid_norm_c are X and centroid_c divided element-wise by within_class_std_dev_ (features with zero standard deviation are left unscaled):

    X (N, F)
      │
      └── Div(within_class_std_dev_) ──► X_norm (N, F)
            │
            └── pairwise distances ──► dists (N, C)
                  │
                  └── Mul(dists, dists) ──► sq_dists (N, C)
                        │
                        └── -sq_dists + 2*log(class_prior_) ──► scores (N, C)
                              │
                              └── ArgMax(axis=1) ──► idx (N,)
                                    │
                                    └── Gather(classes_) ──► label (N,)

Probabilities path (always uses discriminant scores):
sklearn’s predict_proba() always goes through _decision_function, which applies the discriminant score formulation above regardless of prior uniformity. The probabilities are the softmax of those scores:

    scores (N, C)
      │
      ├── ReduceMax(axis=1, keepdims=1) ──────────────► max_s (N, 1)
      │
      ├── Sub(max_s) ─────────────────────────────────► shifted (N, C)
      │
      ├── Exp ────────────────────────────────────────► exp_s (N, C)
      │
      └── Div(ReduceSum(exp_s, axis=1, keepdims=1)) ──► proba (N, C)
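The three paths can be sketched with NumPy, following the node sequence of the diagrams above. This is a minimal reference under stated assumptions, not the converter's code: the helper names (pairwise, labels_uniform, scores, labels_non_uniform, proba) and the toy values are hypothetical, distances are computed by broadcasting, and zero-std features are handled by dividing them by 1 instead of masking.

```python
import numpy as np


def pairwise(X, C, metric="euclidean"):
    """Distances between each row of X (N, F) and each centroid in C (C, F)."""
    diff = X[:, None, :] - C[None, :, :]               # (N, C, F) by broadcasting
    if metric == "euclidean":
        return np.sqrt((diff ** 2).sum(axis=2))        # (N, C)
    if metric == "manhattan":
        return np.abs(diff).sum(axis=2)                # (N, C)
    raise ValueError(f"unsupported metric: {metric}")


def labels_uniform(X, centroids, classes, metric="euclidean"):
    """Uniform-prior path: raw distances -> ArgMin(axis=1) -> Gather(classes_)."""
    idx = np.argmin(pairwise(X, centroids, metric), axis=1)
    return np.take(classes, idx)


def scores(X, centroids, std, prior, metric="euclidean"):
    """score_c = -dist(X_norm, centroid_norm_c)**2 + 2 * log(prior_c)."""
    # Zero-std features are left unscaled: dividing them by 1 instead of 0
    # lets a single Div handle the mask (an assumption of this sketch).
    safe_std = np.where(std == 0, 1.0, std)
    dists = pairwise(X / safe_std, centroids / safe_std, metric)  # dists (N, C)
    return -(dists * dists) + 2.0 * np.log(prior)                 # scores (N, C)


def labels_non_uniform(X, centroids, std, prior, classes, metric="euclidean"):
    """Non-uniform-prior path: scores -> ArgMax(axis=1) -> Gather(classes_)."""
    s = scores(X, centroids, std, prior, metric)
    return np.take(classes, np.argmax(s, axis=1))


def proba(s):
    """Softmax of the scores, shifted by the row max so exp() cannot underflow to 0/0."""
    shifted = s - s.max(axis=1, keepdims=True)         # ReduceMax, then Sub
    exp_s = np.exp(shifted)                            # Exp
    return exp_s / exp_s.sum(axis=1, keepdims=True)    # Div(ReduceSum(...))


# Hypothetical fitted values: two classes, two features,
# the second feature has zero within-class std.
centroids = np.array([[0.0, 0.0], [4.0, 4.0]])
std = np.array([2.0, 0.0])
classes = np.array(["a", "b"])
X = np.array([[1.0, 0.5], [2.1, 2.0]])

print(labels_uniform(X, centroids, classes))                  # ['a' 'b']
prior = np.array([0.9, 0.1])                                  # non-uniform priors
print(labels_non_uniform(X, centroids, std, prior, classes))  # ['a' 'a']
print(proba(scores(X, centroids, std, prior)).sum(axis=1))    # each row sums to 1
```

np.argmin and np.argmax return the first index on ties, which matches the default of the ONNX ArgMin/ArgMax operators (select_last_index=0). When all priors are equal, the 2 * log(prior_c) term is the same for every class and cancels in the softmax, so the probabilities then depend only on the normalised distances.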
Supported metrics: "euclidean" and "manhattan".

- Parameters:
g – graph builder
sts – shapes defined by scikit-learn
outputs – desired output names; outputs[0] receives the predicted labels and outputs[1] (if present) receives the class probabilities
estimator – a fitted NearestCentroid
X – input tensor name
name – prefix for the names of the added nodes
- Returns:
the name of the predicted label tensor, or a tuple (labels, probabilities) of tensor names when a second output is requested
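If a graph is built only from basic tensor operations, the distance matrix dists (N, C) for the two supported metrics has to be assembled from elementwise and reduction nodes. The sketch below shows two standard formulations in NumPy: euclidean through the expansion ||x - c||² = ||x||² - 2·x·c + ||c||², whose cross term is a single matrix product, and manhattan through a broadcast subtraction, an absolute value and a sum over features. Whether the converter uses either formulation is an assumption; the function names are hypothetical.

```python
import numpy as np


def euclidean_expanded(X, C):
    """Euclidean distances via ||x - c||**2 = ||x||**2 - 2 x.c + ||c||**2."""
    x2 = (X * X).sum(axis=1, keepdims=True)     # (N, 1)
    c2 = (C * C).sum(axis=1)[None, :]           # (1, C)
    sq = x2 - 2.0 * (X @ C.T) + c2              # (N, C); the cross term is a MatMul
    return np.sqrt(np.maximum(sq, 0.0))         # clamp tiny negative round-off


def manhattan(X, C):
    """Manhattan distances: broadcast Sub, Abs, then ReduceSum over features."""
    return np.abs(X[:, None, :] - C[None, :, :]).sum(axis=2)   # (N, C)


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))     # 5 samples, 3 features
C = rng.normal(size=(2, 3))     # 2 centroids
direct = np.sqrt(((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2))
print(np.allclose(euclidean_expanded(X, C), direct))   # True
```

The expansion avoids materialising the (N, C, F) difference tensor for the euclidean metric; the manhattan metric has no comparable inner-product form, so in this sketch its intermediate tensor grows with N × C × F.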