Converting a JAX function to ONNX

yobx.tensorflow.to_onnx() can also convert JAX functions to ONNX. Under the hood it uses jax.experimental.jax2tf.convert() to lower the JAX computation to a TensorFlow ConcreteFunction, then applies the same TF→ONNX conversion pipeline used for Keras models.

Alternatively, yobx.tensorflow.tensorflow_helper.jax_to_concrete_function() can be called explicitly to obtain the intermediate ConcreteFunction before passing it to to_onnx().

The workflow is:

  1. Write a plain JAX function (or wrap a Flax or Equinox model in a function).

  2. Call yobx.tensorflow.to_onnx() with a representative dummy input. The converter detects that the callable is a JAX function and automatically routes it through jax_to_concrete_function().

  3. Run the exported ONNX model with any ONNX runtime; this example uses onnxruntime.

  4. Verify that the ONNX outputs match JAX’s own outputs.
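Step 4 comes down to an element-wise comparison of two arrays. The sketch below shows one way to summarize such a comparison with NumPy. compare_outputs is a hypothetical helper written for illustration; it is not the implementation of yobx.helpers.max_diff, and the arrays are made up.

```python
import numpy as np


def compare_outputs(expected, got):
    """Hypothetical helper (not yobx.helpers.max_diff): summarize the gap
    between a reference output and an ONNX runtime output."""
    expected = np.asarray(expected, dtype=np.float64)
    got = np.asarray(got, dtype=np.float64)
    if expected.shape != got.shape:
        raise ValueError(f"shape mismatch: {expected.shape} vs {got.shape}")
    diff = np.abs(expected - got)
    # Relative error, guarding against division by zero for exact zeros.
    rel = diff / np.maximum(np.abs(expected), np.finfo(np.float64).tiny)
    return {
        "abs": float(diff.max()),
        "rel": float(rel.max()),
        "n": int(diff.size),
        # Index of the largest absolute difference.
        "argm": tuple(int(i) for i in np.unravel_index(diff.argmax(), diff.shape)),
    }


# Made-up arrays standing in for the JAX output and the ONNX output.
ref = np.array([[0.5, -1.0], [2.0, 0.0]], dtype=np.float32)
onnx_out = ref + np.array([[0.0, 1e-6], [0.0, 0.0]], dtype=np.float32)
report = compare_outputs(ref, onnx_out)
print(report)
```

Reporting the location of the worst element, not just a pass/fail flag, makes it easier to see whether a mismatch is isolated rounding noise or a systematic conversion error.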

import jax
import jax.numpy as jnp
import numpy as np
import onnxruntime
from yobx.doc import plot_dot
from yobx.helpers import max_diff
from yobx.helpers.onnx_helper import pretty_onnx
from yobx.tensorflow import to_onnx
from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function

1. Simple element-wise function

We start with the simplest possible JAX function: an element-wise sin applied to a float32 matrix. to_onnx() auto-detects that the callable is a JAX function and converts it transparently.

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4)).astype(np.float32)


def jax_sin(x):
    return jnp.sin(x)


onx_sin = to_onnx(jax_sin, (X,))

print("Opset            :", onx_sin.opset_import[0].version)
print("Number of nodes  :", len(onx_sin.graph.node))
print("Node op-types    :", [n.op_type for n in onx_sin.graph.node])
Opset            : 21
Number of nodes  : 1
Node op-types    : ['Sin']

Run and compare

Verify that the ONNX model reproduces the JAX output.

ref_sin = onnxruntime.InferenceSession(
    onx_sin.SerializeToString(), providers=["CPUExecutionProvider"]
)
input_name = ref_sin.get_inputs()[0].name
(result_sin,) = ref_sin.run(None, {input_name: X})

expected_sin = np.asarray(jax_sin(X))
print("\nJAX  output (first row):", expected_sin[0])
print("ONNX output (first row):", result_sin[0])
assert np.allclose(expected_sin, result_sin, atol=1e-5), "Mismatch!"
print("Outputs match ✓ - ", max_diff(expected_sin, result_sin))
JAX  output (first row): [ 0.12539922 -0.13172095  0.5975344   0.10470783]
ONNX output (first row): [ 0.12539922 -0.13172095  0.5975344   0.10470783]
Outputs match ✓ -  {'abs': 5.960464477539063e-08, 'rel': 1.019410678654049e-07, 'sum': 1.7881393432617188e-07, 'n': 20.0, 'dnan': 0.0, 'argm': (2, 2)}
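The maximum absolute difference reported above, 5.960464477539063e-08, is exactly 2**-24. That is the gap between adjacent float32 values (one ULP) in the range [0.5, 1), so the mismatch is consistent with small last-bit rounding differences between the two runtimes' sin implementations. The NumPy check below only verifies this arithmetic; it does not reproduce the JAX or ONNX computation.

```python
import numpy as np

# Gap between adjacent float32 values (one ULP) at a few magnitudes.
for v in (0.1, 0.6, 3.0):
    print(v, np.spacing(np.float32(v)))

# The reported max abs difference equals 2**-24, i.e. half of float32 eps,
# which is exactly one ULP for float32 values in [0.5, 1).
reported_abs = 5.960464477539063e-08
half_eps = float(np.finfo(np.float32).eps) / 2
print(reported_abs == half_eps, reported_abs == 2.0 ** -24)
print(float(np.spacing(np.float32(0.6))) == reported_abs)
```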

2. Multi-layer MLP in JAX

A slightly more complex function: a two-layer MLP with a ReLU activation after the first layer. The weights W1 and W2 are JAX arrays and the biases b1 and b2 are NumPy arrays; all four are captured in a closure.

key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)

W1 = jax.random.normal(k1, (8, 16), dtype=np.float32)
b1 = np.zeros(16, dtype=np.float32)
W2 = jax.random.normal(k2, (16, 4), dtype=np.float32)
b2 = np.zeros(4, dtype=np.float32)


def jax_mlp(x):
    h = jax.nn.relu(x @ W1 + b1)
    return h @ W2 + b2


X_mlp = rng.standard_normal((10, 8)).astype(np.float32)
onx_mlp = to_onnx(jax_mlp, (X_mlp,))

op_types = [n.op_type for n in onx_mlp.graph.node]
print("\nOp-types in the MLP graph:", op_types)
assert "MatMul" in op_types
WARNING:tensorflow:AutoGraph could not transform <bound method StackSummary.extract of <class 'traceback.StackSummary'>> and will run it as-is.
Cause: generators are not supported
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function samefile at 0x715d805249a0> and will run it as-is.
Cause: Unable to locate the source code of <function samefile at 0x715d805249a0>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Op-types in the MLP graph: ['MatMul', 'Max', 'MatMul']

Display the model.

print(pretty_onnx(onx_mlp))
opset: domain='' version=21
input: name='X:0' type=dtype('float32') shape=[10, 8]
init: name='%cst' type=float32 shape=(8, 16)
init: name='%cst_1' type=float32 shape=(16, 4)
init: name='%cst_1_dup' type=float32 shape=() -- array([0.], dtype=float32)
MatMul(X:0, %cst) -> _onx_matmul_jax2tf_arg_0:0
  Max(_onx_matmul_jax2tf_arg_0:0, %cst_1_dup) -> _onx_max_add_matmul_jax2tf_arg_0:0
    MatMul(_onx_max_add_matmul_jax2tf_arg_0:0, %cst_1) -> Identity:0
output: name='Identity:0' type='NOTENSOR' shape=None
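The printed graph has only three nodes. ReLU is exported as a Max against a scalar zero (%cst_1_dup), and no Add nodes appear for the biases, presumably because adding the all-zero b1 and b2 was folded away during conversion. The NumPy sketch below replays that graph node by node and checks it against the original formula. The random stand-in weights are an assumption, since the actual W1 and W2 are JAX arrays not reproduced here, and the match relies on the biases being zero.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for the closure weights; random values, not the JAX ones.
W1 = rng.standard_normal((8, 16)).astype(np.float32)
W2 = rng.standard_normal((16, 4)).astype(np.float32)
b1 = np.zeros(16, dtype=np.float32)
b2 = np.zeros(4, dtype=np.float32)


def mlp_reference(x):
    # Original formulation: relu(x @ W1 + b1) @ W2 + b2.
    h = np.maximum(x @ W1 + b1, 0.0)
    return h @ W2 + b2


def replay_exported_graph(x):
    # Node by node, following the printed graph:
    # MatMul(X, %cst) -> Max(., 0) -> MatMul(., %cst_1), with no bias Add nodes.
    t = x @ W1
    t = np.maximum(t, np.float32(0.0))
    return t @ W2


x = rng.standard_normal((10, 8)).astype(np.float32)
print(np.allclose(mlp_reference(x), replay_exported_graph(x), rtol=0.0, atol=1e-6))
```

With non-zero biases the two functions would diverge, so this shortcut only describes the graph exported for this particular closure.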

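Verify the predictions on the same batch that was used for the conversion. A looser tolerance (atol=1e-2) is used here because the observed differences are larger than in the element-wise case.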

ref_mlp = onnxruntime.InferenceSession(
    onx_mlp.SerializeToString(), providers=["CPUExecutionProvider"]
)
input_name_mlp = ref_mlp.get_inputs()[0].name
(result_mlp,) = ref_mlp.run(None, {input_name_mlp: X_mlp})

expected_mlp = np.asarray(jax_mlp(X_mlp))
np.testing.assert_allclose(expected_mlp, result_mlp, atol=1e-2)
print("MLP predictions match ✓ - ", max_diff(expected_mlp, result_mlp))
MLP predictions match ✓ -  {'abs': 0.0071353912353515625, 'rel': 0.02653457824071944, 'sum': 0.08161251246929169, 'n': 40.0, 'dnan': 0.0, 'argm': (7, 0)}
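The checks in this example use two NumPy functions whose defaults differ: np.testing.assert_allclose defaults to rtol=1e-7, atol=0, while np.allclose defaults to rtol=1e-5, atol=1e-8. Both apply the element-wise criterion |actual - desired| <= atol + rtol * |desired| (for np.allclose, its second argument plays the role of desired). The sketch below, with made-up values and a hypothetical within_tolerance helper, shows why the MLP comparison needs an explicit atol: an absolute gap of about 0.005, similar to the abs value reported above, fails the default criterion but passes with atol=1e-2.

```python
import numpy as np

desired = np.array([1.0, 100.0], dtype=np.float32)
actual = desired + np.array([0.005, 0.005], dtype=np.float32)


def within_tolerance(actual, desired, rtol, atol):
    # Element-wise criterion documented for np.testing.assert_allclose:
    # |actual - desired| <= atol + rtol * |desired|
    return np.abs(actual - desired) <= atol + rtol * np.abs(desired)


# np.testing.assert_allclose defaults: rtol=1e-7, atol=0.
print(within_tolerance(actual, desired, rtol=1e-7, atol=0.0))
# The MLP check keeps the default rtol but adds atol=1e-2.
print(within_tolerance(actual, desired, rtol=1e-7, atol=1e-2))
np.testing.assert_allclose(actual, desired, atol=1e-2)  # passes
```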

3. Dynamic batch dimension

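By default, to_onnx() keeps the shapes of the dummy inputs fixed: the MLP graph printed above has an input of shape [10, 8]. Passing dynamic_shapes=({0: "batch"},) marks axis 0 of the first input as a dynamic (symbolic) batch dimension, so the converted model runs correctly for any batch size. The symbolic name stored in the ONNX model (dim in the output below) does not necessarily match the name passed in dynamic_shapes.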

onx_dyn = to_onnx(jax_mlp, (X_mlp,), dynamic_shapes=({0: "batch"},))

input_shape = onx_dyn.graph.input[0].type.tensor_type.shape
batch_dim = input_shape.dim[0]
print("\nBatch dimension param  :", batch_dim.dim_param)
assert batch_dim.dim_param, "Expected a named dynamic dimension"

ref_dyn = onnxruntime.InferenceSession(
    onx_dyn.SerializeToString(), providers=["CPUExecutionProvider"]
)
input_name_dyn = ref_dyn.get_inputs()[0].name
for n in (1, 7, 20):
    X_batch = rng.standard_normal((n, 8)).astype(np.float32)
    (out,) = ref_dyn.run(None, {input_name_dyn: X_batch})
    expected = np.asarray(jax_mlp(X_batch))
    np.testing.assert_allclose(expected, out, atol=1e-2)
    print(f"Dynamic-batch model verified for batch sizes {n} ✓ - ", max_diff(expected, out))
Batch dimension param  : dim
Dynamic-batch model verified for batch sizes 1 ✓ -  {'abs': 4.76837158203125e-07, 'rel': 1.137647515021789e-06, 'sum': 1.5795230865478516e-06, 'n': 4.0, 'dnan': 0.0, 'argm': (0, 0)}
Dynamic-batch model verified for batch sizes 7 ✓ -  {'abs': 0.006735265254974365, 'rel': 0.07791697009657664, 'sum': 0.042095357552170753, 'n': 28.0, 'dnan': 0.0, 'argm': (1, 3)}
Dynamic-batch model verified for batch sizes 20 ✓ -  {'abs': 0.005284786224365234, 'rel': 0.15446957497655453, 'sum': 0.11850462667644024, 'n': 80.0, 'dnan': 0.0, 'argm': (9, 3)}
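None of the three nodes in the MLP graph depends on the size of axis 0, which is why a single exported graph can serve every batch size. The NumPy sketch below illustrates that shape propagation; the stand-in weights and the symbolic_output_shape helper are hypothetical and only mimic the shape rule of MatMul followed by an element-wise Max.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in weights with the same shapes as W1 (8, 16) and W2 (16, 4);
# the values are random, not the ones used in the example above.
W1 = rng.standard_normal((8, 16)).astype(np.float32)
W2 = rng.standard_normal((16, 4)).astype(np.float32)


def mlp_graph(x):
    # Same node sequence as the exported graph: MatMul -> Max(., 0) -> MatMul.
    return np.maximum(x @ W1, np.float32(0.0)) @ W2


def symbolic_output_shape(input_shape):
    # Hypothetical shape rule for (batch, 8) @ (8, 16) @ (16, 4): only the
    # inner sizes must agree, so the leading entry passes through unchanged.
    batch, features = input_shape
    if features != W1.shape[0]:
        raise ValueError(f"expected {W1.shape[0]} features, got {features}")
    return (batch, W2.shape[1])


print(symbolic_output_shape(("batch", 8)))  # the batch entry stays symbolic
for n in (1, 7, 20):
    out = mlp_graph(rng.standard_normal((n, 8)).astype(np.float32))
    print(n, out.shape, symbolic_output_shape((n, 8)))
```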

4. Explicit jax_to_concrete_function

jax_to_concrete_function() can be called directly when you want to inspect or reuse the intermediate ConcreteFunction before exporting to ONNX.

def jax_softmax(x):
    return jax.nn.softmax(x, axis=-1)


X_cls = rng.standard_normal((6, 10)).astype(np.float32)

cf = jax_to_concrete_function(jax_softmax, (X_cls,), dynamic_shapes=({0: "batch"},))
onx_cls = to_onnx(cf, (X_cls,), dynamic_shapes=({0: "batch"},))

ref_cls = onnxruntime.InferenceSession(
    onx_cls.SerializeToString(), providers=["CPUExecutionProvider"]
)
input_name_cls = ref_cls.get_inputs()[0].name
(result_cls,) = ref_cls.run(None, {input_name_cls: X_cls})

expected_cls = np.asarray(jax_softmax(X_cls))
assert np.allclose(expected_cls, result_cls, atol=1e-5), "Softmax mismatch!"
print("Explicit jax_to_concrete_function verified ✓ - ", max_diff(expected_cls, result_cls))
Explicit jax_to_concrete_function verified ✓ -  {'abs': 4.470348358154297e-08, 'rel': 2.2931874917945164e-07, 'sum': 6.062909960746765e-07, 'n': 60.0, 'dnan': 0.0, 'argm': (0, 0)}
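The exported softmax model is only checked numerically above; its nodes are not printed. As a reference for what it computes, the NumPy sketch below implements the usual numerically stable softmax, which subtracts the per-row maximum before exponentiating (to our understanding, jax.nn.softmax does the same). It is a reference for the formula, not a dump of the converted graph, and stable_softmax is a hypothetical helper.

```python
import numpy as np


def stable_softmax(x, axis=-1):
    # Subtracting the row maximum cancels between numerator and denominator,
    # so the result is unchanged, but exp() can no longer overflow.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)


rng = np.random.default_rng(0)
x = rng.standard_normal((6, 10)).astype(np.float32)
p = stable_softmax(x)
print(p.sum(axis=-1))  # each row sums to 1 up to float32 rounding

# A naive exp(x) / sum(exp(x)) overflows to inf for these float32 inputs;
# the shifted version stays finite.
big = np.array([[1000.0, 1001.0]], dtype=np.float32)
print(stable_softmax(big))
```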

5. Visualize the ONNX graph

plot_dot(onx_mlp)
(Figure: rendered graph of the exported MLP model, onx_mlp.)

Total running time of the script: (0 minutes 11.176 seconds)

Related examples

Converting a TensorFlow/Keras model to ONNX