Custom Converter
The built-in converter registry covers estimators that ship with scikit-learn. When you train a custom estimator — or want to override how a built-in estimator is translated — you can supply your own converter without touching the package source.
There are two ways:
- Ad-hoc, via the `extra_converters` parameter of `to_onnx` — useful for one-off conversions or during development.
- Permanent, via the `register_sklearn_converter` decorator — the right choice once a converter is stable and reusable.
Writing a converter function
A converter follows the same contract as all built-in ones:
(g, sts, outputs, estimator, *input_names, name) → output_name(s)
| Parameter | Description |
|---|---|
| `g` | The graph builder used to emit ONNX nodes (e.g. `g.op.Mul(...)`). |
| `sts` | A dictionary of shared conversion state (see the converter registry documentation for its exact contents). |
| `outputs` | The list of output names the converter must produce. |
| `estimator` | The fitted scikit-learn object. |
| `*input_names` | One positional argument per input of the estimator. |
| `name` | String prefix for unique node-name generation. |
Ad-hoc conversion with extra_converters
Pass a {EstimatorClass: converter_function} mapping to the
extra_converters keyword argument. Entries in that mapping take
priority over built-in converters, so you can also override an
existing converter this way.
The example below defines a custom ScaleByConstant transformer and
its corresponding ONNX converter, then converts an instance to ONNX and
validates the result numerically.
<<<
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from yobx.sklearn import to_onnx
from yobx.helpers.onnx_helper import pretty_onnx
# ── 1. Custom estimator ────────────────────────────────────────────
class ScaleByConstant(TransformerMixin, BaseEstimator):
    """Multiplies every feature by a fixed scalar constant."""

    def __init__(self, scale=2.0):
        # The multiplicative factor applied to every feature.
        self.scale = scale

    def fit(self, X, y=None):
        # Stateless transformer: nothing is learned from the data.
        return self

    def transform(self, X):
        # Element-wise scaling; broadcasts across all features.
        return X * self.scale
# ── 2. Converter function ──────────────────────────────────────────
def convert_scale_by_constant(g, sts, outputs, estimator, X, name="scale"):
    """Emits a single ``Mul`` node: output = X * estimator.scale."""
    # The scalar is materialized as a 1-element float32 initializer so the
    # Mul broadcasts it against every feature column.
    factor = np.array([estimator.scale], dtype=np.float32)
    return g.op.Mul(X, factor, name=name, outputs=outputs)
# Build a small float32 sample, fit the custom estimator, and convert it.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3)).astype(np.float32)
est = ScaleByConstant(scale=3.0).fit(X)
# The mapping takes priority over any built-in converter for this class.
onx = to_onnx(est, (X,), extra_converters={ScaleByConstant: convert_scale_by_constant})
print(pretty_onnx(onx))
>>>
opset: domain='' version=21
opset: domain='ai.onnx.ml' version=5
input: name='X' type=dtype('float32') shape=['batch', 3]
init: name='init1_s1_' type=float32 shape=(1,) -- array([3.], dtype=float32)-- Opset.make_node.1/Small
Mul(X, init1_s1_) -> Y
output: name='Y' type='NOTENSOR' shape=None
Validate numerically
<<<
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from yobx.sklearn import to_onnx
from yobx.reference import ExtendedReferenceEvaluator
class ScaleByConstant(TransformerMixin, BaseEstimator):
    """Toy transformer that multiplies every feature by ``scale``."""

    def __init__(self, scale=2.0):
        self.scale = scale

    def fit(self, X, y=None):
        # No fitted state; return self so calls can be chained.
        return self

    def transform(self, X):
        return X * self.scale
def convert_scale_by_constant(g, sts, outputs, estimator, X, name="scale"):
    # One Mul node: output = X * estimator.scale, with the constant
    # stored as a single-element float32 tensor.
    const = np.array([estimator.scale], dtype=np.float32)
    mul = g.op.Mul(X, const, name=name, outputs=outputs)
    return mul
# Convert the estimator, then compare ONNX output against sklearn output.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3)).astype(np.float32)
est = ScaleByConstant(scale=3.0).fit(X)
onx = to_onnx(est, (X,), extra_converters={ScaleByConstant: convert_scale_by_constant})
# Run the model with the reference evaluator (no onnxruntime needed).
ref = ExtendedReferenceEvaluator(onx)
onnx_output = ref.run(None, {"X": X})[0]
# Cast the sklearn result to float32 so the comparison is dtype-consistent.
sklearn_output = est.transform(X).astype(np.float32)
print("max absolute difference:", np.abs(onnx_output - sklearn_output).max())
>>>
max absolute difference: 0.0
Overriding a built-in converter
Because extra_converters entries take priority, you can also replace
the converter for a built-in estimator. The snippet below replaces the
standard sklearn.preprocessing.StandardScaler converter with a
trivial identity (just to illustrate the override mechanism):
<<<
import numpy as np
from sklearn.preprocessing import StandardScaler
from yobx.sklearn import to_onnx
from yobx.helpers.onnx_helper import pretty_onnx
def identity_scaler(g, sts, outputs, estimator, X, name="scaler"):
    """Pass-through: return the input unchanged."""
    # A single Identity node is enough to illustrate the override mechanism.
    return g.op.Identity(X, name=name, outputs=outputs)
# Fit a real StandardScaler, then convert it with the identity override.
rng = np.random.default_rng(1)
X = rng.standard_normal((4, 2)).astype(np.float32)
ss = StandardScaler().fit(X)
# The custom converter overrides the built-in one
onx = to_onnx(ss, (X,), extra_converters={StandardScaler: identity_scaler})
print(pretty_onnx(onx))
>>>
opset: domain='' version=21
opset: domain='ai.onnx.ml' version=5
input: name='X' type=dtype('float32') shape=['batch', 2]
Identity(X) -> x
output: name='x' type='NOTENSOR' shape=None
Permanent registration
Once your converter is stable, promote it from an ad-hoc function to a
first-class entry in the registry by using the
register_sklearn_converter decorator. This
means you no longer have to pass extra_converters at every call site:
# myproject/onnx_converters.py
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from yobx.sklearn.register import register_sklearn_converter
from yobx.typing import GraphBuilderExtendedProtocol
from yobx.xbuilder import GraphBuilder
class ScaleByConstant(TransformerMixin, BaseEstimator):
    """Transformer that scales every feature by a fixed constant."""

    def __init__(self, scale=2.0):
        self.scale = scale

    def fit(self, X, y=None):
        # Nothing to fit — the scale is a constructor parameter.
        return self

    def transform(self, X):
        return X * self.scale
@register_sklearn_converter(ScaleByConstant)
def convert_scale_by_constant(
    g: GraphBuilderExtendedProtocol,
    sts: dict,
    outputs: list,
    estimator: ScaleByConstant,
    X: str,
    name: str = "scale",
) -> str:
    """Registered converter: emit ``outputs = X * estimator.scale`` as one Mul node."""
    constant = np.array([estimator.scale], dtype=np.float32)
    return g.op.Mul(X, constant, name=name, outputs=outputs)
Once this module is imported the converter is available globally and
to_onnx will use it automatically:
import myproject.onnx_converters # registers the converter
from yobx.sklearn import to_onnx
# The registry now resolves ScaleByConstant automatically.
onx = to_onnx(ScaleByConstant(scale=3.0).fit(X), (X,))
# no extra_converters needed
See also
Sklearn Converter — overview of the converter registry, the built-in converters, and how to add a new converter to the package itself.