# TensorFlow / JAX Export to ONNX

`yobx.tensorflow.to_onnx()` converts a TensorFlow/Keras model, or a JAX
function, into an `onnx.ModelProto`. The implementation is a
proof-of-concept that traces the model with
`tf.function` / `get_concrete_function` and then converts each TF
operation in the resulting computation graph to its ONNX equivalent via
a registry of op-level converters.
## High-level workflow

```text
Keras model / layer
        │
        ▼
to_onnx()                       ← builds TensorSpecs, calls get_concrete_function
        │
        ▼
ConcreteFunction (TF computation graph)
        │
        ▼
_convert_concrete_function()
        │  for every op in the graph …
        ▼
op converter                    ← emits ONNX node(s) via GraphBuilder.op.*
        │
        ▼
GraphBuilder.to_onnx()          ← validates and returns ModelProto
```
The steps in detail:

1. `to_onnx` accepts the Keras model, a tuple of representative numpy
   inputs (used to infer dtypes and shapes), and optional
   `input_names` / `dynamic_shapes`.
2. A `tf.TensorSpec` is built for every input. By default, the batch
   axis (axis 0) is made dynamic; pass `dynamic_shapes` to customise
   which axes are dynamic.
3. `get_concrete_function()` traces the model with those specs,
   yielding a `tensorflow.ConcreteFunction` whose `graph` exposes every
   individual TF operation in execution order.
4. A fresh `GraphBuilder` is created and `_convert_concrete_function()`
   walks the op list:
   - Captured variables (model weights) are seeded as ONNX initializers.
   - Placeholder ops that correspond to real inputs are registered via
     `make_tensor_input`.
   - Every other op is dispatched to a registered converter (or to an
     entry in `extra_converters`).
5. The ONNX outputs are declared with `make_tensor_output`.
6. `GraphBuilder.to_onnx` finalises and returns the `onnx.ModelProto`.
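The dispatch in step 4 can be sketched in plain Python. The stub op
class and the toy converters below are invented for illustration; the
real code walks a `ConcreteFunction` graph and emits ONNX nodes instead
of strings:

```python
# Toy sketch of the dispatch loop in _convert_concrete_function().
# A list of stub ops stands in for the traced TF graph, and each
# "converter" simply records which op it handled.

class StubOp:
    def __init__(self, type_, name):
        self.type = type_
        self.name = name

def convert_ops(ops, converters, extra_converters=None):
    """Dispatch every op: extra_converters first, then the registry."""
    extra_converters = extra_converters or {}
    emitted = []
    for op in ops:
        conv = extra_converters.get(op.type) or converters.get(op.type)
        if conv is None:
            raise RuntimeError(f"no converter registered for op type {op.type!r}")
        emitted.append(conv(op))
    return emitted

converters = {
    "MatMul": lambda op: f"onnx::MatMul({op.name})",
    "Relu": lambda op: f"onnx::Relu({op.name})",
}
ops = [StubOp("MatMul", "dense/MatMul"), StubOp("Relu", "dense/Relu")]
print(convert_ops(ops, converters))
# → ['onnx::MatMul(dense/MatMul)', 'onnx::Relu(dense/Relu)']
```

Looking up `extra_converters` before the registry is what lets callers
override a built-in converter without touching the package.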
## Quick example (Keras)

```python
import numpy as np
import tensorflow as tf

from yobx.tensorflow import to_onnx

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
X = np.random.rand(5, 4).astype(np.float32)
onx = to_onnx(model, (X,))
```
## JAX support

`to_onnx()` also accepts plain JAX functions. When it detects that the
callable is a JAX function (TF tracing raises a `TypeError` about
abstract arrays), it automatically falls back to
`jax_to_concrete_function()`, which uses `jax2tf` to lower the JAX
computation to a `ConcreteFunction` before applying the standard
TF→ONNX pipeline.
```python
import jax.numpy as jnp
import numpy as np

from yobx.tensorflow import to_onnx

def jax_fn(x):
    return jnp.sin(x)

X = np.random.rand(5, 4).astype(np.float32)
onx = to_onnx(jax_fn, (X,))
```
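The detection described above amounts to a try/except around tracing.
A minimal sketch, with stub helpers standing in for
`get_concrete_function` and `jax_to_concrete_function` (the stubs and
their behaviour are invented for illustration):

```python
# Sketch of the TF-vs-JAX fallback: try TF tracing first; if it fails
# with the TypeError raised for JAX abstract arrays, lower via jax2tf.

def tf_trace(fn):
    """Stub for TF tracing; fails on JAX callables like the real thing."""
    if getattr(fn, "is_jax", False):
        raise TypeError("Cannot convert an abstract array to a TF tensor")
    return ("tf", fn)

def jax_lower(fn):
    """Stub for jax_to_concrete_function (jax2tf lowering)."""
    return ("jax2tf", fn)

def to_concrete(fn):
    try:
        return tf_trace(fn)
    except TypeError:
        return jax_lower(fn)

def keras_like(): ...
def jax_like(): ...
jax_like.is_jax = True

print(to_concrete(keras_like)[0])  # tf
print(to_concrete(jax_like)[0])    # jax2tf
```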
You can also call `jax_to_concrete_function()` explicitly when you want
to inspect or reuse the intermediate `ConcreteFunction`:

```python
import numpy as np

from yobx.tensorflow import to_onnx
from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function

# jax_fn and X as defined in the previous example
cf = jax_to_concrete_function(jax_fn, (X,), dynamic_shapes=({0: "batch"},))
onx = to_onnx(cf, (X,), dynamic_shapes=({0: "batch"},))
```
See Converting a JAX function to ONNX for a complete gallery example.
## Converter registry

The registry is a module-level dictionary
`TF_OP_CONVERTERS: Dict[str, Callable]` defined in
`yobx.tensorflow.register`. Keys are TF op-type strings
(e.g. `"MatMul"`, `"Relu"`); values are converter callables.
### Registering a converter

Use the `register_tf_op_converter` decorator. Pass a single op-type
string or a tuple of strings:

```python
from yobx.tensorflow.register import register_tf_op_converter
from yobx.xbuilder import GraphBuilder

@register_tf_op_converter(("MyOp", "MyOpV2"))
def convert_my_op(g: GraphBuilder, sts: dict, outputs: list, op) -> str:
    return g.op.SomeOnnxOp(op.inputs[0].name, outputs=outputs, name=op.name)
```
The decorator raises `TypeError` if any of the given op-type strings is
already present in the global registry (including duplicates within the
same tuple, since each string is registered before the next one is
checked).
### Looking up a converter

`get_tf_op_converter` accepts an op-type string and returns the
registered callable, or `None` if none is found.
`get_tf_op_converters` returns a copy of the full registry dictionary.
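As a self-contained sketch of these semantics (not the package's actual
source), the decorator, its duplicate check, and the two lookup helpers
could look like this:

```python
# Minimal reimplementation of the registry behaviour described above:
# register one or more op-type strings, refuse duplicates with
# TypeError, and expose lookup helpers.

from typing import Callable, Dict, Tuple, Union

TF_OP_CONVERTERS: Dict[str, Callable] = {}

def register_tf_op_converter(op_types: Union[str, Tuple[str, ...]]):
    names = (op_types,) if isinstance(op_types, str) else op_types
    def wrap(fn: Callable) -> Callable:
        for name in names:
            # Each string is registered before the next is checked,
            # so duplicates within one tuple also raise.
            if name in TF_OP_CONVERTERS:
                raise TypeError(f"converter for {name!r} already registered")
            TF_OP_CONVERTERS[name] = fn
        return fn
    return wrap

def get_tf_op_converter(op_type: str):
    return TF_OP_CONVERTERS.get(op_type)

def get_tf_op_converters():
    return dict(TF_OP_CONVERTERS)  # a copy, not the live registry

@register_tf_op_converter(("MyOp", "MyOpV2"))
def convert_my_op(g, sts, outputs, op):
    return outputs[0]

assert get_tf_op_converter("MyOp") is convert_my_op
assert get_tf_op_converter("Unknown") is None
```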
### Converter function signature

Every op converter follows the same contract:

```text
(g, sts, outputs, op[, verbose]) → output_name
```

| Parameter | Description |
|---|---|
| `g` | The `GraphBuilder` collecting the emitted ONNX nodes. |
| `sts` | A dict of shape/type information for known tensors. |
| `outputs` | The list of ONNX output names the converter must produce. |
| `op` | The TF operation being converted. |
| `verbose` | Optional verbosity level (default `0`). |

The function should return the name of the primary output tensor.
Input tensor names are obtained via `op.inputs[i].name` and attribute
values via `op.get_attr("attr_name")`.
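To make the contract concrete, here is a toy converter run against stub
objects; the `LeakyRelu` converter, the stub classes, and the recording
`GraphBuilder` are all invented for illustration and are not the real
`yobx` types:

```python
# A toy converter exercising the contract: read input names from
# op.inputs[i].name, attributes via op.get_attr(), emit through
# g.op.*, and return the primary output name.

class StubTensor:
    def __init__(self, name):
        self.name = name

class StubOp:
    def __init__(self, name, inputs, attrs):
        self.name = name
        self.inputs = [StubTensor(n) for n in inputs]
        self._attrs = attrs
    def get_attr(self, key):
        return self._attrs[key]

class StubOps:
    """Records every emitted node; returns the first output name."""
    def __init__(self, nodes):
        self._nodes = nodes
    def __getattr__(self, op_type):
        def emit(*inputs, outputs, name, **attrs):
            self._nodes.append((op_type, inputs, tuple(outputs), name, attrs))
            return outputs[0]
        return emit

class StubGraphBuilder:
    def __init__(self):
        self.nodes = []
        self.op = StubOps(self.nodes)

def convert_leaky_relu(g, sts, outputs, op, verbose=0):
    alpha = float(op.get_attr("alpha"))
    return g.op.LeakyRelu(op.inputs[0].name, alpha=alpha,
                          outputs=outputs, name=op.name)

g = StubGraphBuilder()
op = StubOp("act/LeakyRelu", ["dense/BiasAdd:0"], {"alpha": 0.2})
out = convert_leaky_relu(g, {}, ["y"], op)
print(out)  # y
```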
## Supported ops

See Supported TF Ops for the full list of built-in TF op converters,
generated automatically from the live registry.
## Dynamic shapes

By default `to_onnx` marks axis 0 of every input as dynamic (unnamed
batch dimension). To control which axes are dynamic, pass
`dynamic_shapes`: a tuple of one `Dict[int, str]` per input, where keys
are axis indices and values are symbolic dimension names:

```python
# axis 0 named "batch", axis 1 fixed
onx = to_onnx(model, (X,), dynamic_shapes=({0: "batch"},))
```
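The mapping from a `dynamic_shapes` entry to a spec shape can be
sketched in pure Python, under the assumption that a dynamic axis
simply becomes `None` in the `tf.TensorSpec` shape (with the symbolic
name carried through to the ONNX `dim_param`); `spec_shape` is a
hypothetical helper, not part of the package:

```python
# Sketch: derive a TensorSpec-style shape from a sample input shape
# and a dynamic_shapes entry. Named axes become None (dynamic); all
# other axes keep the concrete extent from the sample input.

def spec_shape(sample_shape, dyn=None):
    """dyn: Dict[int, str] mapping axis index -> symbolic name."""
    dyn = {0: ""} if dyn is None else dyn  # default: unnamed dynamic batch
    return tuple(None if axis in dyn else size
                 for axis, size in enumerate(sample_shape))

print(spec_shape((5, 4)))                          # (None, 4)
print(spec_shape((5, 4), {0: "batch", 1: "dim"}))  # (None, None)
print(spec_shape((5, 4), {}))                      # (5, 4) fully static
```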
## Custom op converters

The `extra_converters` parameter of `to_onnx` accepts a mapping from TF
op-type string to converter function. Entries here take priority over
the built-in registry, making it easy to override or extend coverage
without modifying the package.
```python
import numpy as np
import tensorflow as tf

from yobx.tensorflow import to_onnx

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="relu", input_shape=(3,))
])
X = np.random.rand(5, 3).astype(np.float32)

def custom_relu(g, sts, outputs, op):
    """Replace Relu with a Clip(0, 1) to saturate outputs at 1."""
    return g.op.Clip(
        op.inputs[0].name,
        np.array(0.0, dtype=np.float32),
        np.array(1.0, dtype=np.float32),
        outputs=outputs[:1],
        name=op.name,
    )

onx = to_onnx(model, (X,), extra_converters={"Relu": custom_relu})
```
## Adding a new built-in converter

To extend the built-in op coverage:

1. Create a new file under `yobx/tensorflow/ops/`
   (e.g. `yobx/tensorflow/ops/reduce.py`).
2. Implement a converter function following the signature above.
3. Decorate it with `@register_tf_op_converter("Sum")` (or a tuple for
   multiple op types), using the TF op-type string.
4. Import the new module inside the `register()` function in
   `yobx/tensorflow/ops/__init__.py`.
```python
# yobx/tensorflow/ops/reduce.py
import numpy as np
from onnx import TensorProto

from ..register import register_tf_op_converter
from ...xbuilder import GraphBuilder

@register_tf_op_converter("Sum")
def convert_reduce_sum(g: GraphBuilder, sts: dict, outputs: list, op) -> str:
    """TF ``Sum`` → ONNX ``ReduceSum``."""
    keepdims = int(op.get_attr("keep_dims"))
    # TF may pass a 0-D scalar for single-axis reductions; ONNX requires 1-D.
    axes_i64 = g.op.Cast(op.inputs[1].name, to=TensorProto.INT64,
                         name=f"{op.name}_cast")
    axes = g.op.Reshape(axes_i64, np.array([-1], dtype=np.int64),
                        name=f"{op.name}_axes")
    return g.op.ReduceSum(
        op.inputs[0].name, axes, keepdims=keepdims, outputs=outputs, name=op.name
    )
```
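The normalisation that the comment in the converter refers to can be
checked with plain numpy: a 0-D axes scalar and a 1-D axes vector both
come out 1-D after reshaping with `[-1]`:

```python
# TF's Sum may pass its reduction axes as a 0-D scalar, while ONNX
# ReduceSum expects a 1-D int64 tensor. Casting then reshaping with
# [-1] normalises both cases, mirroring the Cast + Reshape pair above.
import numpy as np

scalar_axes = np.array(1, dtype=np.int32)        # 0-D: single axis
vector_axes = np.array([0, 2], dtype=np.int32)   # already 1-D

print(scalar_axes.astype(np.int64).reshape(-1))  # [1]
print(vector_axes.astype(np.int64).reshape(-1))  # [0 2]
```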