Alternative GraphBuilderExtendedProtocol
GraphBuilderExtendedProtocol is the interface that every graph builder used by the yobx.sklearn converters must satisfy. The package ships with three concrete implementations and makes it easy to add more:

- GraphBuilder: the default; builds graphs using onnx protobuf objects with built-in optimisation passes.
- OnnxScriptGraphBuilder: delegates graph construction to the onnxscript IR.
- SpoxGraphBuilder: delegates graph construction to the spox library.
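The converters depend only on the protocol, so any class that provides the required methods can be used. It does not need to share a base class with the others. The sketch below illustrates this structural-typing idea with Python's `typing.Protocol`. `MiniGraphBuilder`, its two methods, `RecordingBuilder` and `convert_scaler` are all hypothetical and heavily simplified. The real GraphBuilderExtendedProtocol exposes many more methods.

```python
from typing import Any, List, Protocol, runtime_checkable


@runtime_checkable
class MiniGraphBuilder(Protocol):
    """Hypothetical, trimmed-down stand-in for a builder protocol."""

    def make_initializer(self, name: str, value: Any) -> str: ...

    def make_node(self, op_type: str, inputs: List[str], output: str) -> str: ...


class RecordingBuilder:
    """Satisfies MiniGraphBuilder structurally: it inherits from nothing."""

    def __init__(self) -> None:
        self.initializers = {}
        self.nodes = []

    def make_initializer(self, name: str, value: Any) -> str:
        self.initializers[name] = value
        return name

    def make_node(self, op_type: str, inputs: List[str], output: str) -> str:
        # Record the node as a plain tuple; a real builder would create ONNX objects.
        self.nodes.append((op_type, tuple(inputs), output))
        return output


def convert_scaler(g: MiniGraphBuilder, x: str, mean: Any, scale: Any) -> str:
    """Hypothetical converter written against the protocol, not a concrete class."""
    m = g.make_initializer("mean", mean)
    s = g.make_initializer("scale", scale)
    centered = g.make_node("Sub", [x, m], "centered")
    return g.make_node("Div", [centered, s], "x_scaled")


builder = RecordingBuilder()
print(isinstance(builder, MiniGraphBuilder))  # True
out = convert_scaler(builder, "X", [0.0, 1.0], [1.0, 2.0])
print(out, builder.nodes)
```

A third-party builder only has to provide the same methods to be accepted. `runtime_checkable` checks method presence, not signatures, so static type checkers remain the stronger guard.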
Why provide alternatives?
Keeping the builders behind a protocol rather than inheriting from a single base class means that any third-party library can supply its own builder. Some reasons for doing so:
- Better IDE / type support: spox and onnxscript both use strongly-typed, opset-versioned Python functions, so mistakes are caught statically rather than at runtime.
- Validation on construction: spox validates the graph structure incrementally, so type errors surface when a node is added rather than at export time.
- Integration into an existing IR pipeline: if the rest of the workflow already works with onnxscript’s ir.Model, it is more convenient to accumulate nodes there directly and avoid a round-trip through onnx.ModelProto.
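The difference between validating on construction and validating at export time can be shown with two toy builders. This is a hypothetical sketch of the general idea, not spox's implementation. Both classes and their method names are made up for illustration. The only check they perform is that every input name is already defined.

```python
class EagerBuilder:
    """Toy builder that validates each node as soon as it is added."""

    def __init__(self, inputs):
        self.known = set(inputs)
        self.nodes = []

    def make_node(self, op_type, inputs, output):
        missing = [n for n in inputs if n not in self.known]
        if missing:
            # The error is raised at the call that creates the faulty node.
            raise ValueError(f"{op_type}: unknown inputs {missing}")
        self.nodes.append((op_type, tuple(inputs), output))
        self.known.add(output)
        return output

    def export(self):
        return self.nodes


class DeferredBuilder:
    """Toy builder that only checks the graph when it is exported."""

    def __init__(self, inputs):
        self.inputs = list(inputs)
        self.nodes = []

    def make_node(self, op_type, inputs, output):
        self.nodes.append((op_type, tuple(inputs), output))
        return output

    def export(self):
        known = set(self.inputs)
        for op_type, inputs, output in self.nodes:
            missing = [n for n in inputs if n not in known]
            if missing:
                raise ValueError(f"{op_type}: unknown inputs {missing}")
            known.add(output)
        return self.nodes


def add_buggy_nodes(g):
    y = g.make_node("Relu", ["X"], "y")
    g.make_node("Add", [y, "typo"], "z")  # "typo" was never defined
    g.make_node("Neg", ["z"], "w")        # still accepted by the deferred builder


def failure_point(builder):
    """Report whether the bug was caught while building or only at export."""
    try:
        add_buggy_nodes(builder)
    except ValueError as e:
        return f"while adding nodes: {e}"
    try:
        builder.export()
    except ValueError as e:
        return f"at export: {e}"
    return "no error"


print(failure_point(EagerBuilder(["X"])))     # while adding nodes: Add: unknown inputs ['typo']
print(failure_point(DeferredBuilder(["X"])))  # at export: Add: unknown inputs ['typo']
```

With eager checking, the traceback points at the converter line that produced the bad node. The deferred builder reports the same problem later and further from its cause.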
Using OnnxScriptGraphBuilder
OnnxScriptGraphBuilder is a bridge that builds an onnxscript ir.Model internally. To converters, it presents the same string-based API as the other builders, in which tensors are referred to by name.
<<<
import numpy as np
import onnx
from sklearn.preprocessing import StandardScaler
from yobx.sklearn import to_onnx
from yobx.builder.onnxscript import OnnxScriptGraphBuilder
from yobx.helpers.onnx_helper import pretty_onnx
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 4)).astype(np.float32)
scaler = StandardScaler().fit(X)
model = to_onnx(scaler, (X,), builder_cls=OnnxScriptGraphBuilder)
print(pretty_onnx(model))
>>>
opset: domain='' version=21
opset: domain='ai.onnx.ml' version=5
input: name='X' type=dtype('float32') shape=['batch', 4]
init: name='init_' type=float32 shape=(4,) -- array([-0.448, 0.052, -0.093, 0.247], dtype=float32)
init: name='init_2' type=float32 shape=(4,) -- array([0.774, 0.641, 0.825, 0.728], dtype=float32)
Sub(X, init_) -> Sub
Div(Sub, init_2) -> x
output: name='x' type=dtype('float32') shape=['batch', 4]
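The bridge lets converters keep working with plain string names while the backing library stores its own value objects. A minimal sketch of that adapter pattern follows. `Value`, `Node` and `NameBridge` are hypothetical stand-ins, not onnxscript's IR classes, and their method names only loosely echo the builder interface. The constants reuse the scaler initializers from the listing above.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Value:
    """Hypothetical backend value object (not onnxscript's ir.Value)."""

    name: str
    producer: Optional[str] = None  # op_type of the producing node; None for inputs/initializers
    data: Any = None                # constant payload for initializers


@dataclass
class Node:
    op_type: str
    inputs: List[Value]
    outputs: List[Value]


class NameBridge:
    """Hypothetical bridge: converters pass strings, the backend stores Value objects."""

    def __init__(self) -> None:
        self.values: Dict[str, Value] = {}
        self.nodes: List[Node] = []

    def make_tensor_input(self, name: str) -> str:
        self.values[name] = Value(name)
        return name

    def make_initializer(self, name: str, data: Any) -> str:
        self.values[name] = Value(name, data=data)
        return name

    def make_node(self, op_type: str, inputs: List[str], output: Optional[str] = None) -> str:
        # Translate each input name into the backend object registered under it.
        in_values = [self.values[name] for name in inputs]
        # Derive an output name from the op type when the converter gives none.
        if output is None:
            output = f"{op_type}_{len(self.nodes)}"
        out_value = Value(output, producer=op_type)
        self.values[output] = out_value
        self.nodes.append(Node(op_type, in_values, [out_value]))
        # Converters only ever see the name, never the Value object.
        return output


g = NameBridge()
x = g.make_tensor_input("X")
mean = g.make_initializer("init_", [-0.448, 0.052, -0.093, 0.247])
scale = g.make_initializer("init_2", [0.774, 0.641, 0.825, 0.728])
centered = g.make_node("Sub", [x, mean])  # auto-named "Sub_0"
result = g.make_node("Div", [centered, scale], output="x")
print([n.op_type for n in g.nodes])  # ['Sub', 'Div']
# The Div node's first input is the very object produced by Sub.
print(g.nodes[1].inputs[0] is g.nodes[0].outputs[0])  # True
```

Converter code can therefore stay unchanged whichever builder is plugged in. Only the bridge knows how names map to the library's objects.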
Using SpoxGraphBuilder
SpoxGraphBuilder is a bridge
that delegates every operator call to the matching spox opset
module, providing static type-checking and incremental graph validation.
The only change relative to the default workflow is passing
builder_cls=SpoxGraphBuilder to yobx.sklearn.to_onnx():
<<<
import numpy as np
import onnx
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from yobx.sklearn import to_onnx
from yobx.builder.spox import SpoxGraphBuilder
from yobx.helpers.onnx_helper import pretty_onnx
rng = np.random.default_rng(0)
X = rng.standard_normal((80, 4)).astype(np.float32)
y = (X[:, 0] + X[:, 1] > 0).astype(int)
pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
pipe.fit(X, y)
model = to_onnx(pipe, (X[:1],), builder_cls=SpoxGraphBuilder)
print(pretty_onnx(model))
>>>
opset: domain='' version=21
input: name='X' type=dtype('float32') shape=['batch', 4]
Constant(value=[0, 1]) -> Constant_0_output
Constant(value=[1.0]) -> Constant_1_output
Constant(value=[-0.040599...) -> Constant_2_output
Sub(X, Constant_2_output) -> Sub_0_C
Constant(value=[0.9541545...) -> Constant_3_output
Div(Sub_0_C, Constant_3_output) -> Div_0_C
Constant(value=[[2.229245...) -> Constant_4_output
Constant(value=[0.2121082...) -> Constant_5_output
Gemm(Div_0_C, Constant_4_output, Constant_5_output, alpha=1.00, beta=1.00, transA=0, transB=1) -> Gemm_0_Y
Sigmoid(Gemm_0_Y) -> Sigmoid_0_Y
Sub(Constant_1_output, Sigmoid_0_Y) -> Sub_1_C
Concat(Sub_1_C, Sigmoid_0_Y, axis=-1) -> Concat_0_concat_result
ArgMax(Concat_0_concat_result, axis=1, keepdims=0, select_last_index=0) -> ArgMax_0_reduced
Cast(ArgMax_0_reduced, saturate=1, to=7) -> Cast_0_output
Gather(Constant_0_output, Cast_0_output, axis=0) -> Gather_0_output
Identity(Gather_0_output) -> label
Identity(Concat_0_concat_result) -> probabilities
output: name='label' type=dtype('int64') shape=['batch']
output: name='probabilities' type=dtype('float32') shape=['batch', 2]
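The node chain printed above is a standard scaler followed by a binary logistic regression. The NumPy sketch below re-implements it node by node to show what each operator contributes. The parameter values are made up, because the constants are truncated in the listing. Only the structure follows the graph: Sub, Div, Gemm with transB=1, Sigmoid, 1 - p, Concat, ArgMax, Cast and Gather. `to=7` is the ONNX TensorProto code for INT64.

```python
import numpy as np

# Made-up parameters standing in for the truncated constants in the listing.
mean = np.array([-0.04, 0.1, 0.0, 0.2], dtype=np.float32)     # Constant_2_output
scale = np.array([0.95, 1.1, 1.0, 0.9], dtype=np.float32)     # Constant_3_output
coef = np.array([[2.2, 2.1, 0.1, -0.05]], dtype=np.float32)   # Constant_4_output, shape (1, 4)
intercept = np.array([0.2], dtype=np.float32)                 # Constant_5_output
classes = np.array([0, 1], dtype=np.int64)                    # Constant_0_output


def run_graph(X):
    z = (X - mean) / scale                            # Sub, Div (StandardScaler)
    logit = z @ coef.T + intercept                    # Gemm(transB=1), shape (n, 1)
    p1 = 1.0 / (1.0 + np.exp(-logit))                 # Sigmoid: P(class 1)
    proba = np.concatenate([1.0 - p1, p1], axis=-1)   # Sub(1, p), Concat(axis=-1)
    idx = np.argmax(proba, axis=1).astype(np.int64)   # ArgMax(keepdims=0), Cast(to=INT64)
    label = classes[idx]                              # Gather(axis=0)
    return label, proba


rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4)).astype(np.float32)
label, proba = run_graph(X)
print(label.dtype, proba.shape)  # int64 (5, 2)
```

Each row of `probabilities` sums to one. The label is whichever class has the larger probability, mapped back through the stored class values.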
See also
- Expected API: the full list of methods and attributes every builder must expose.
- SpoxGraphBuilder: a complete, production-quality alternative implementation backed by spox.
- OnnxScriptGraphBuilder: a complete alternative backed by the onnxscript IR.