yobx.sklearn.sksurv.ensemble

yobx.sklearn.sksurv.ensemble.sklearn_random_survival_forest(g: GraphBuilderExtendedProtocol, sts: Dict, outputs: List[str], estimator: RandomSurvivalForest, X: str, name: str = 'random_survival_forest') -> str

Converts a sksurv.ensemble.RandomSurvivalForest into ONNX.

Algorithm overview

RandomSurvivalForest is an ensemble of SurvivalTree estimators. The risk-score prediction for a sample is the average of the individual tree predictions:

predict(x) = mean_t [ sum_{j: is_event_time[j]} CHF_t(T_j | x) ]

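where T_j is the j-th unique time in the training data, is_event_time[j] is true when at least one event (not only censoring) was observed at T_j, and CHF_t(T_j | x) is the cumulative hazard function value at T_j stored in the leaf that sample x reaches in tree t.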

Because each tree’s contribution reduces to a scalar per leaf, the forest is equivalent to a standard TreeEnsembleRegressor once the leaf weights are pre-computed as the CHF sum over observed event times.
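The reduction can be checked with a minimal pure-Python sketch. It assumes a hypothetical toy layout (each tree is a dict from leaf id to the CHF values stored in that leaf, one per unique training time); this is neither the sksurv data structure nor the converter's code. It compares the formula evaluated per sample with a regressor-style prediction that looks up one pre-computed scalar per tree.

```python
from typing import Dict, List, Sequence

# Hypothetical toy layout (not the sksurv API): a tree maps each leaf id to the
# CHF values stored in that leaf, one value per unique training time T_j.
Tree = Dict[int, List[float]]


def precompute_leaf_weights(trees: List[Tree], is_event_time: Sequence[bool]) -> List[Dict[int, float]]:
    """Collapse every leaf's CHF curve to one scalar: its sum over event times."""
    return [
        {leaf: sum(h for h, ev in zip(chf, is_event_time) if ev) for leaf, chf in tree.items()}
        for tree in trees
    ]


def predict_from_weights(weights: List[Dict[int, float]], leaves: Sequence[int]) -> float:
    """Regressor-style prediction: look up one scalar per tree and average them."""
    return sum(w[leaf] for w, leaf in zip(weights, leaves)) / len(weights)


def predict_direct(trees: List[Tree], leaves: Sequence[int], is_event_time: Sequence[bool]) -> float:
    """The formula evaluated per sample: mean over trees of the CHF summed at event times."""
    per_tree = [
        sum(h for h, ev in zip(tree[leaf], is_event_time) if ev)
        for tree, leaf in zip(trees, leaves)
    ]
    return sum(per_tree) / len(per_tree)


if __name__ == "__main__":
    is_event_time = [True, False, True]  # the second unique time has only censored samples
    trees = [
        {1: [0.1, 0.2, 0.5], 2: [0.0, 0.3, 0.4]},
        {3: [0.2, 0.2, 0.9], 4: [0.1, 0.1, 0.1]},
    ]
    leaves = [1, 4]  # leaf reached by one sample in each tree
    weights = precompute_leaf_weights(trees, is_event_time)
    print(predict_from_weights(weights, leaves), predict_direct(trees, leaves, is_event_time))
```

Both calls should print the same value (about 0.4). Because the leaf weights do not depend on the sample, they can be computed once at conversion time and stored as ordinary regression leaf values.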

Graph structure:

X ──TreeEnsemble[Regressor]──► risk_scores (N, 1)

When ai.onnx.ml opset 5 (or later) is available, the unified TreeEnsemble operator is used: the leaf weights are pre-divided by n_estimators and aggregated with aggregate_function=1 (SUM). Otherwise, the legacy TreeEnsembleRegressor operator is emitted with aggregate_function="AVERAGE". Both forms return the mean of the per-tree leaf weights.
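The equivalence of the two code paths can be illustrated with a small sketch. The helper names plan_tree_op, aggregate and run are hypothetical and only mimic the choice described above; they do not build ONNX nodes. The integer codes follow the ai.onnx.ml TreeEnsemble specification, where aggregate_function is an int (0 = AVERAGE, 1 = SUM), whereas TreeEnsembleRegressor takes a string.

```python
from typing import List, Sequence, Tuple, Union

# aggregate_function codes of ai.onnx.ml TreeEnsemble (opset 5); the legacy
# TreeEnsembleRegressor uses the strings "AVERAGE" / "SUM" instead.
AGGREGATE_AVERAGE = 0
AGGREGATE_SUM = 1

Aggregate = Union[int, str]


def plan_tree_op(ml_opset: int, leaf_weights: List[List[float]]) -> Tuple[str, Aggregate, List[List[float]]]:
    """Hypothetical helper: pick the operator, its aggregation mode and the leaf weights."""
    n_estimators = len(leaf_weights)
    if ml_opset >= 5:
        # Dividing every weight by n_estimators turns the SUM over trees into a mean.
        scaled = [[w / n_estimators for w in tree] for tree in leaf_weights]
        return "TreeEnsemble", AGGREGATE_SUM, scaled
    return "TreeEnsembleRegressor", "AVERAGE", [list(tree) for tree in leaf_weights]


def aggregate(mode: Aggregate, reached: Sequence[float]) -> float:
    """Combine the weights of the leaves one sample reaches (one weight per tree)."""
    if mode in (AGGREGATE_SUM, "SUM"):
        return sum(reached)
    if mode in (AGGREGATE_AVERAGE, "AVERAGE"):
        return sum(reached) / len(reached)
    raise ValueError(f"unsupported aggregate_function: {mode!r}")


def run(ml_opset: int, leaf_weights: List[List[float]], leaves: Sequence[int]) -> float:
    """Simulate one sample: leaves[t] is the index of the leaf it reaches in tree t."""
    _op_type, mode, weights = plan_tree_op(ml_opset, leaf_weights)
    reached = [tree[leaf] for tree, leaf in zip(weights, leaves)]
    return aggregate(mode, reached)


if __name__ == "__main__":
    leaf_weights = [[0.6, 0.7], [0.9, 0.2], [0.3, 0.5]]  # 3 trees, 2 leaves each
    leaves = [0, 1, 1]
    print(run(5, leaf_weights, leaves), run(3, leaf_weights, leaves))
```

The two results agree up to floating-point rounding, since dividing before summing and dividing after summing can differ in the last bits.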

Parameters:
  • g – the graph builder to add nodes to

  • sts – shapes and types defined by scikit-learn

  • outputs – desired output tensor names

  • estimator – a fitted RandomSurvivalForest

  • X – name of the input tensor

  • name – prefix used for names of nodes added by this converter

Returns:

output tensor name