yobx.sklearn.multioutput.classifier_chain

yobx.sklearn.multioutput.classifier_chain.sklearn_classifier_chain(g: GraphBuilderExtendedProtocol, sts: Dict, outputs: List[str], estimator: ClassifierChain, X: str, name: str = 'classifier_chain') → str | Tuple[str, str]

Converts a sklearn.multioutput.ClassifierChain into ONNX.

Each sub-estimator in the chain predicts one binary target from the original features augmented with the binary predictions of all preceding steps (in chain order). After the last step, the per-step predictions and probabilities are reordered to match the original target column order.
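The chaining logic can be sketched in NumPy, independently of ONNX. This is only an illustration under assumptions: `chain_predict`, `est0` and `est1` are hypothetical names, and the two threshold functions stand in for fitted binary sub-estimators; none of them are part of yobx or scikit-learn.

```python
import numpy as np


def chain_predict(X, estimators, order):
    """Run a classifier chain by hand; `estimators` are callables in chain order."""
    chain_preds = np.zeros((X.shape[0], len(estimators)), dtype=np.float32)
    for step, predict in enumerate(estimators):
        # Step k sees the original features plus the predictions of steps 0..k-1.
        X_aug = np.hstack([X, chain_preds[:, :step]])
        chain_preds[:, step] = predict(X_aug)
    # Column k of chain_preds holds target order[k]; undo that permutation.
    return chain_preds[:, np.argsort(order)]


# Hypothetical stand-ins for fitted binary classifiers.
def est0(X_aug):
    # First step: receives only the two original features.
    return (X_aug[:, 0] > 0).astype(np.float32)


def est1(X_aug):
    # Second step: the last column is the prediction made by est0.
    return (X_aug[:, 1] + X_aug[:, -1] > 0.5).astype(np.float32)


X = np.array([[1.0, 0.0], [-1.0, 1.0], [2.0, -3.0]], dtype=np.float32)
# order=[1, 0]: step 0 predicts target 1, step 1 predicts target 0.
labels = chain_predict(X, [est0, est1], order=[1, 0])
print(labels)
```

In the converter, the `np.hstack` call corresponds to the Concat node feeding each sub-estimator, and the final indexing corresponds to the reordering of the concatenated predictions.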

Graph structure (labels only, identity order):

X ──[est 0 converter]──► label_0 ── Cast(float)──Reshape(N,1) ──┐ pred_0_col
│                                                               │
Concat(X, pred_0_col) ──[est 1 converter]──► label_1 ──Cast──Reshape──┐ pred_1_col
│                                                                       │
...                         +-------------------------------------------+
        Concat(axis=1) ─────+──────► labels (N, n_targets)

When the chain's order_ attribute is not the identity permutation, the concatenated predictions (in chain order) are reordered via Gather using the inverse permutation, so that the output columns match the original target order.
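As a small worked example of that reordering, the snippet below builds the inverse permutation for a three-target chain and applies it with `np.take`, which behaves like an ONNX Gather on axis 1. The array values are made up for illustration: each entry encodes its target index in the tens digit.

```python
import numpy as np

# Chain step k predicts target order[k] (same meaning as ClassifierChain.order_).
order = np.array([2, 0, 1])

# Inverse permutation: inv_order[j] is the chain step that predicted target j.
inv_order = np.empty_like(order)
inv_order[order] = np.arange(len(order))

# Predictions concatenated in chain order, i.e. columns are targets 2, 0, 1.
chain_preds = np.array([[20, 0, 10],
                        [21, 1, 11]])

# Reorder only when needed; np.take on axis=1 mirrors Gather(axis=1).
if np.array_equal(order, np.arange(len(order))):
    labels = chain_preds
else:
    labels = np.take(chain_preds, inv_order, axis=1)

print(inv_order)  # [1 2 0]
print(labels)     # columns are now targets 0, 1, 2
```

`np.argsort(order)` yields the same inverse permutation, since `order` contains each index exactly once.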

Graph structure (with probabilities):

The probability for class 1 is extracted from each sub-estimator’s (N, 2) probability output, reshaped to (N, 1), concatenated into (N, n_targets) in chain order, then reordered in the same way as the labels.
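The probability path can be mimicked the same way. The sketch below assumes two hypothetical `(N, 2)` probability arrays standing in for the sub-estimators' outputs; the values are invented and only the slicing, reshaping, concatenation and reordering steps are meant to match the description.

```python
import numpy as np

# Hypothetical (N, 2) probability outputs of two sub-estimators, in chain order.
proba_step0 = np.array([[0.9, 0.1], [0.2, 0.8]], dtype=np.float32)
proba_step1 = np.array([[0.4, 0.6], [0.7, 0.3]], dtype=np.float32)
order = np.array([1, 0])  # step 0 predicts target 1, step 1 predicts target 0

# Keep the class-1 column of each step and reshape it to (N, 1).
columns = [p[:, 1].reshape(-1, 1) for p in (proba_step0, proba_step1)]

# Concatenate into (N, n_targets), still in chain order.
chain_proba = np.concatenate(columns, axis=1)

# Reorder with the inverse permutation, exactly as for the labels.
probabilities = np.take(chain_proba, np.argsort(order), axis=1)
print(probabilities)
```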

Parameters:
  • g – the graph builder to add nodes to

  • sts – shapes and types defined by scikit-learn

  • outputs – desired output tensor names (label, or label + probabilities)

  • estimator – a fitted ClassifierChain

  • X – name of the input tensor

  • name – prefix used for names of nodes added by this converter

Returns:

label tensor name, or tuple (label, probabilities)

Raises:

NotImplementedError – when probabilities are requested but sub-estimators do not expose predict_proba()
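A caller may want to detect this situation before invoking the converter. The helper below is a hypothetical sketch, not the converter's actual check: `check_proba_support` and the two dummy classes are illustrative names, and it assumes that passing a second output name is what requests probabilities.

```python
def check_proba_support(sub_estimators, n_outputs):
    """Raise early if probabilities are requested but cannot be produced."""
    # Assumption: a second output name means probabilities are requested.
    wants_proba = n_outputs > 1
    missing = [
        type(est).__name__
        for est in sub_estimators
        if not hasattr(est, "predict_proba")
    ]
    if wants_proba and missing:
        raise NotImplementedError(
            f"probabilities requested but {missing} do not expose predict_proba()"
        )


# Dummy stand-ins for fitted sub-estimators.
class WithProba:
    def predict_proba(self, X):
        return None


class LabelsOnly:
    def predict(self, X):
        return None


check_proba_support([WithProba(), LabelsOnly()], n_outputs=1)  # labels only: fine
check_proba_support([WithProba(), WithProba()], n_outputs=2)   # all support proba: fine
```

With a fitted `ClassifierChain`, the list of sub-estimators would be its `estimators_` attribute.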