.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_tensorflow/plot_jax_to_onnx.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_tensorflow_plot_jax_to_onnx.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_tensorflow_plot_jax_to_onnx.py:

.. _l-plot-jax-to-onnx:

Converting a JAX function to ONNX
=================================

:func:`yobx.tensorflow.to_onnx` can also convert :epkg:`JAX` functions to
ONNX. Under the hood it uses :func:`jax.experimental.jax2tf.convert` to lower
the JAX computation to a :class:`tensorflow.ConcreteFunction` and then applies
the same TF→ONNX conversion pipeline used for Keras models. Alternatively,
:func:`yobx.tensorflow.tensorflow_helper.jax_to_concrete_function` can be
called explicitly to obtain the intermediate
:class:`~tensorflow.ConcreteFunction` before passing it to
:func:`~yobx.tensorflow.to_onnx`.

The workflow is:

1. **Write** a plain JAX function (or wrap a :mod:`flax`/:mod:`equinox`
   model in a function).
2. **Call** :func:`yobx.tensorflow.to_onnx` with a representative
   *dummy input*. The converter detects that the callable is a JAX function
   and automatically routes it through
   :func:`~yobx.tensorflow.tensorflow_helper.jax_to_concrete_function`.
3. **Run** the exported ONNX model with any ONNX runtime — this example uses
   :epkg:`onnxruntime`.
4. **Verify** that the ONNX outputs match JAX's own outputs.

.. GENERATED FROM PYTHON SOURCE LINES 30-40

.. code-block:: Python

    import jax
    import jax.numpy as jnp
    import numpy as np
    import onnxruntime

    from yobx.doc import plot_dot
    from yobx.helpers import max_diff
    from yobx.helpers.onnx_helper import pretty_onnx
    from yobx.tensorflow import to_onnx
    from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function

.. GENERATED FROM PYTHON SOURCE LINES 41-47
1. Simple element-wise function
-------------------------------

We start with the simplest possible JAX function: an element-wise ``sin``
applied to a float32 matrix. :func:`to_onnx` auto-detects that the callable
is a JAX function and converts it transparently.

.. GENERATED FROM PYTHON SOURCE LINES 47-62

.. code-block:: Python

    rng = np.random.default_rng(0)
    X = rng.standard_normal((5, 4)).astype(np.float32)


    def jax_sin(x):
        return jnp.sin(x)


    onx_sin = to_onnx(jax_sin, (X,))

    print("Opset :", onx_sin.opset_import[0].version)
    print("Number of nodes :", len(onx_sin.graph.node))
    print("Node op-types :", [n.op_type for n in onx_sin.graph.node])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Opset : 21
    Number of nodes : 1
    Node op-types : ['Sin']

.. GENERATED FROM PYTHON SOURCE LINES 63-67

Run and compare
~~~~~~~~~~~~~~~

Verify that the ONNX model reproduces the JAX output.

.. GENERATED FROM PYTHON SOURCE LINES 67-80

.. code-block:: Python

    ref_sin = onnxruntime.InferenceSession(
        onx_sin.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name = ref_sin.get_inputs()[0].name
    (result_sin,) = ref_sin.run(None, {input_name: X})

    expected_sin = np.asarray(jax_sin(X))
    print("\nJAX output (first row):", expected_sin[0])
    print("ONNX output (first row):", result_sin[0])

    assert np.allclose(expected_sin, result_sin, atol=1e-5), "Mismatch!"
    print("Outputs match ✓ - ", max_diff(expected_sin, result_sin))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    JAX output (first row): [ 0.12539922 -0.13172095  0.5975344   0.10470783]
    ONNX output (first row): [ 0.12539922 -0.13172095  0.5975344   0.10470783]
    Outputs match ✓ -  {'abs': 5.960464477539063e-08, 'rel': 1.019410678654049e-07, 'sum': 1.7881393432617188e-07, 'n': 20.0, 'dnan': 0.0, 'argm': (2, 2)}

.. GENERATED FROM PYTHON SOURCE LINES 81-86
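The ``max_diff`` helper used above reports a small dictionary of discrepancy
statistics. Its actual implementation is not shown in this example; the
following is a hypothetical numpy reimplementation of the fields appearing
in the output (``rough_max_diff``, the 1e-19 guard, and the ``dnan``
semantics are all guesses, not the library's code):

```python
import numpy as np


def rough_max_diff(expected, got):
    """Hypothetical sketch of a discrepancy report in the spirit of
    yobx.helpers.max_diff; field names follow the printed output, the
    real implementation may differ."""
    expected = np.asarray(expected, dtype=np.float64)
    got = np.asarray(got, dtype=np.float64)
    diff = np.abs(expected - got)
    rel = diff / (np.abs(expected) + 1e-19)  # guessed guard against /0
    return {
        "abs": float(diff.max()),   # largest absolute error
        "rel": float(rel.max()),    # largest relative error
        "sum": float(diff.sum()),   # total absolute error
        "n": float(expected.size),  # number of compared values
        # guessed semantics: flags a mismatch in NaN counts
        "dnan": float(np.isnan(expected).sum() != np.isnan(got).sum()),
        # index of the worst absolute error
        "argm": tuple(int(i) for i in np.unravel_index(diff.argmax(), diff.shape)),
    }


a = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
b = a + np.float32(1e-6)
print(rough_max_diff(a, b))
```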
2. Multi-layer MLP in JAX
-------------------------

A slightly more complex function: a two-layer MLP with ReLU activations
whose weights are stored as JAX arrays captured in a closure.

.. GENERATED FROM PYTHON SOURCE LINES 86-108

.. code-block:: Python

    key = jax.random.PRNGKey(42)
    k1, k2 = jax.random.split(key)
    W1 = jax.random.normal(k1, (8, 16), dtype=np.float32)
    b1 = np.zeros(16, dtype=np.float32)
    W2 = jax.random.normal(k2, (16, 4), dtype=np.float32)
    b2 = np.zeros(4, dtype=np.float32)


    def jax_mlp(x):
        h = jax.nn.relu(x @ W1 + b1)
        return h @ W2 + b2


    X_mlp = rng.standard_normal((10, 8)).astype(np.float32)
    onx_mlp = to_onnx(jax_mlp, (X_mlp,))

    op_types = [n.op_type for n in onx_mlp.graph.node]
    print("\nOp-types in the MLP graph:", op_types)
    assert "MatMul" in op_types

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    WARNING:tensorflow:AutoGraph could not transform > and will run it as-is.
    Cause: generators are not supported
    To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
    WARNING:tensorflow:AutoGraph could not transform  and will run it as-is.
    Cause: Unable to locate the source code of . Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
    To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

    Op-types in the MLP graph: ['MatMul', 'Max', 'MatMul']

.. GENERATED FROM PYTHON SOURCE LINES 109-110

Display the model.

.. GENERATED FROM PYTHON SOURCE LINES 110-112

.. code-block:: Python

    print(pretty_onnx(onx_mlp))

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    opset: domain='' version=21
    input: name='X:0' type=dtype('float32') shape=[10, 8]
    init: name='%cst' type=float32 shape=(8, 16)
    init: name='%cst_1' type=float32 shape=(16, 4)
    init: name='%cst_1_dup' type=float32 shape=() -- array([0.], dtype=float32)
    MatMul(X:0, %cst) -> _onx_matmul_jax2tf_arg_0:0
      Max(_onx_matmul_jax2tf_arg_0:0, %cst_1_dup) -> _onx_max_add_matmul_jax2tf_arg_0:0
        MatMul(_onx_max_add_matmul_jax2tf_arg_0:0, %cst_1) -> Identity:0
    output: name='Identity:0' type='NOTENSOR' shape=None

.. GENERATED FROM PYTHON SOURCE LINES 113-114

Verify predictions on a held-out batch.

.. GENERATED FROM PYTHON SOURCE LINES 114-125

.. code-block:: Python

    ref_mlp = onnxruntime.InferenceSession(
        onx_mlp.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name_mlp = ref_mlp.get_inputs()[0].name
    (result_mlp,) = ref_mlp.run(None, {input_name_mlp: X_mlp})

    expected_mlp = np.asarray(jax_mlp(X_mlp))
    np.testing.assert_allclose(expected_mlp, result_mlp, atol=1e-2)
    print("MLP predictions match ✓ - ", max_diff(expected_mlp, result_mlp))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    MLP predictions match ✓ -  {'abs': 0.0071353912353515625, 'rel': 0.02653457824071944, 'sum': 0.08161251246929169, 'n': 40.0, 'dnan': 0.0, 'argm': (7, 0)}

.. GENERATED FROM PYTHON SOURCE LINES 126-131

3. Dynamic batch dimension
--------------------------

By default :func:`to_onnx` marks axis 0 as a dynamic (symbolic) batch
dimension. The converted model runs correctly for any batch size.

.. GENERATED FROM PYTHON SOURCE LINES 131-150
.. code-block:: Python

    onx_dyn = to_onnx(jax_mlp, (X_mlp,), dynamic_shapes=({0: "batch"},))

    input_shape = onx_dyn.graph.input[0].type.tensor_type.shape
    batch_dim = input_shape.dim[0]
    print("\nBatch dimension param :", batch_dim.dim_param)
    assert batch_dim.dim_param, "Expected a named dynamic dimension"

    ref_dyn = onnxruntime.InferenceSession(
        onx_dyn.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name_dyn = ref_dyn.get_inputs()[0].name

    for n in (1, 7, 20):
        X_batch = rng.standard_normal((n, 8)).astype(np.float32)
        (out,) = ref_dyn.run(None, {input_name_dyn: X_batch})
        expected = np.asarray(jax_mlp(X_batch))
        np.testing.assert_allclose(expected, out, atol=1e-2)
        print(f"Dynamic-batch model verified for batch size {n} ✓ - ", max_diff(expected, out))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Batch dimension param : dim
    Dynamic-batch model verified for batch size 1 ✓ -  {'abs': 4.76837158203125e-07, 'rel': 1.137647515021789e-06, 'sum': 1.5795230865478516e-06, 'n': 4.0, 'dnan': 0.0, 'argm': (0, 0)}
    Dynamic-batch model verified for batch size 7 ✓ -  {'abs': 0.006735265254974365, 'rel': 0.07791697009657664, 'sum': 0.042095357552170753, 'n': 28.0, 'dnan': 0.0, 'argm': (1, 3)}
    Dynamic-batch model verified for batch size 20 ✓ -  {'abs': 0.005284786224365234, 'rel': 0.15446957497655453, 'sum': 0.11850462667644024, 'n': 80.0, 'dnan': 0.0, 'argm': (9, 3)}

.. GENERATED FROM PYTHON SOURCE LINES 151-157

4. Explicit jax_to_concrete_function
------------------------------------

:func:`~yobx.tensorflow.tensorflow_helper.jax_to_concrete_function` can be
called directly when you want to inspect or reuse the intermediate
:class:`~tensorflow.ConcreteFunction` before exporting to ONNX.

.. GENERATED FROM PYTHON SOURCE LINES 157-178
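A :class:`~tensorflow.ConcreteFunction` can be inspected like any other
traced TensorFlow function. The sketch below uses a plain ``tf.function``
stand-in; assuming ``jax_to_concrete_function`` returns a standard
:class:`~tensorflow.ConcreteFunction` (as its name suggests), the same
attributes apply to its result:

```python
import tensorflow as tf


@tf.function
def tf_softmax(x):
    return tf.nn.softmax(x, axis=-1)


# Trace with a symbolic batch axis, mirroring dynamic_shapes=({0: "batch"},).
cf = tf_softmax.get_concrete_function(tf.TensorSpec([None, 10], tf.float32))

# Both the input signature and the traced graph ops are available.
print(cf.structured_input_signature)
print(sorted({op.type for op in cf.graph.get_operations()}))
```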
.. code-block:: Python

    def jax_softmax(x):
        return jax.nn.softmax(x, axis=-1)


    X_cls = rng.standard_normal((6, 10)).astype(np.float32)

    cf = jax_to_concrete_function(jax_softmax, (X_cls,), dynamic_shapes=({0: "batch"},))
    onx_cls = to_onnx(cf, (X_cls,), dynamic_shapes=({0: "batch"},))

    ref_cls = onnxruntime.InferenceSession(
        onx_cls.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    input_name_cls = ref_cls.get_inputs()[0].name
    (result_cls,) = ref_cls.run(None, {input_name_cls: X_cls})

    expected_cls = np.asarray(jax_softmax(X_cls))
    assert np.allclose(expected_cls, result_cls, atol=1e-5), "Softmax mismatch!"
    print("Explicit jax_to_concrete_function verified ✓ - ", max_diff(expected_cls, result_cls))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Explicit jax_to_concrete_function verified ✓ -  {'abs': 4.470348358154297e-08, 'rel': 2.2931874917945164e-07, 'sum': 6.062909960746765e-07, 'n': 60.0, 'dnan': 0.0, 'argm': (0, 0)}

.. GENERATED FROM PYTHON SOURCE LINES 179-182

5. Visualize the ONNX graph
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 182-183

.. code-block:: Python

    plot_dot(onx_mlp)

.. image-sg:: /auto_examples_tensorflow/images/sphx_glr_plot_jax_to_onnx_001.png
   :alt: plot jax to onnx
   :srcset: /auto_examples_tensorflow/images/sphx_glr_plot_jax_to_onnx_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 11.176 seconds)

.. _sphx_glr_download_auto_examples_tensorflow_plot_jax_to_onnx.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_jax_to_onnx.ipynb <plot_jax_to_onnx.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_jax_to_onnx.py <plot_jax_to_onnx.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_jax_to_onnx.zip <plot_jax_to_onnx.zip>`

.. include:: plot_jax_to_onnx.recommendations

.. only:: html
  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_