.. _l-design-tensorflow-converter:

================================
TensorFlow / JAX Export to ONNX
================================

.. toctree::
    :maxdepth: 1

    supported_ops
    supported_jax_ops

:func:`yobx.tensorflow.to_onnx` converts a :epkg:`TensorFlow`/:epkg:`Keras`
model — or a :epkg:`JAX` function — into an :class:`onnx.ModelProto`.
The implementation is a **proof-of-concept** that traces the model with
:func:`tensorflow.function` / ``get_concrete_function`` and then converts
each TF operation in the resulting computation graph to its ONNX equivalent
via a registry of op-level converters.

High-level workflow
===================

.. code-block:: text

    Keras model / layer
          │
          ▼
    to_onnx()                      ← builds TensorSpecs, calls get_concrete_function
          │
          ▼
    ConcreteFunction (TF computation graph)
          │
          ▼
    _convert_concrete_function()
          │   for every op in the graph …
          ▼
    op converter                   ← emits ONNX node(s) via GraphBuilder.op.*
          │
          ▼
    GraphBuilder.to_onnx()         ← validates and returns ModelProto

The steps in detail:

1. :func:`to_onnx <yobx.tensorflow.to_onnx>` accepts the Keras model, a tuple
   of representative *numpy* inputs (used to infer dtypes and shapes), and
   optional ``input_names`` / ``dynamic_shapes``.
2. A :class:`tensorflow.TensorSpec` is built for every input. By default, the
   batch axis (axis 0) is made dynamic; pass ``dynamic_shapes`` to customise
   which axes are dynamic.
3. :func:`get_concrete_function` traces the model with those specs, yielding a
   :class:`tensorflow.ConcreteFunction` whose ``graph`` exposes every
   individual TF operation in execution order.
4. A fresh :class:`~yobx.xbuilder.GraphBuilder` is created and
   :func:`_convert_concrete_function` walks the op list:

   a. **Captured variables** (model weights) are seeded as ONNX initializers.
   b. **Placeholder ops** that correspond to real inputs are registered via
      :meth:`make_tensor_input <yobx.xbuilder.GraphBuilder.make_tensor_input>`.
   c. Every other op is dispatched to a registered converter (or to an entry
      in ``extra_converters``).

5. The ONNX outputs are declared with
   :meth:`make_tensor_output <yobx.xbuilder.GraphBuilder.make_tensor_output>`.
6. :meth:`GraphBuilder.to_onnx <yobx.xbuilder.GraphBuilder.to_onnx>` finalises
   and returns the :class:`onnx.ModelProto`.

Quick example (Keras)
=====================

.. code-block:: python

    import numpy as np
    import tensorflow as tf

    from yobx.tensorflow import to_onnx

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
        tf.keras.layers.Dense(2),
    ])
    X = np.random.rand(5, 4).astype(np.float32)

    onx = to_onnx(model, (X,))

JAX support
===========

:func:`~yobx.tensorflow.to_onnx` also accepts plain :epkg:`JAX` functions.
When it detects that the callable is a JAX function (TF tracing raises a
``TypeError`` about abstract arrays), it automatically falls back to
:func:`~yobx.tensorflow.tensorflow_helper.jax_to_concrete_function`, which
uses :epkg:`jax2tf` to lower the JAX computation to a
:class:`~tensorflow.ConcreteFunction` before applying the standard TF→ONNX
pipeline.

.. code-block:: python

    import jax.numpy as jnp
    import numpy as np

    from yobx.tensorflow import to_onnx

    def jax_fn(x):
        return jnp.sin(x)

    X = np.random.rand(5, 4).astype(np.float32)
    onx = to_onnx(jax_fn, (X,))

You can also call
:func:`~yobx.tensorflow.tensorflow_helper.jax_to_concrete_function`
explicitly when you want to inspect or reuse the intermediate
:class:`~tensorflow.ConcreteFunction`:

.. code-block:: python

    import numpy as np

    from yobx.tensorflow import to_onnx
    from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function

    cf = jax_to_concrete_function(jax_fn, (X,), dynamic_shapes=({0: "batch"},))
    onx = to_onnx(cf, (X,), dynamic_shapes=({0: "batch"},))

See :ref:`l-plot-jax-to-onnx` for a complete gallery example.

Converter registry
==================

The registry is a module-level dictionary
``TF_OP_CONVERTERS: Dict[str, Callable]`` defined in
:mod:`yobx.tensorflow.register`. Keys are TF op-type strings
(e.g. ``"MatMul"``, ``"Relu"``); values are converter callables.

Registering a converter
-----------------------

Use the
:func:`register_tf_op_converter <yobx.tensorflow.register.register_tf_op_converter>`
decorator.
Pass a single op-type string or a tuple of strings:

.. code-block:: python

    from yobx.tensorflow.register import register_tf_op_converter
    from yobx.xbuilder import GraphBuilder

    @register_tf_op_converter(("MyOp", "MyOpV2"))
    def convert_my_op(g: GraphBuilder, sts: dict, outputs: list, op) -> str:
        return g.op.SomeOnnxOp(op.inputs[0].name, outputs=outputs, name=op.name)

The decorator raises :class:`TypeError` if any of the given op-type strings is
already present in the global registry (including duplicates within the same
tuple, since each string is registered before the next one is checked).

Looking up a converter
----------------------

:func:`get_tf_op_converter <yobx.tensorflow.register.get_tf_op_converter>`
accepts an op-type string and returns the registered callable, or ``None`` if
none is found.
:func:`get_tf_op_converters <yobx.tensorflow.register.get_tf_op_converters>`
returns a copy of the full registry dictionary.

Converter function signature
============================

Every op converter follows the same contract:

``(g, sts, outputs, op[, verbose]) → output_name``

============= =====================================================
Parameter     Description
============= =====================================================
``g``         :class:`GraphBuilder <yobx.xbuilder.GraphBuilder>` —
              call ``g.op.*`` to emit ONNX nodes.
``sts``       ``Dict`` of metadata (currently always ``{}``).
``outputs``   ``List[str]`` of pre-allocated output tensor names
              that the converter **must** write to.
``op``        A :class:`tensorflow.Operation` whose ``inputs``,
              ``outputs``, and attributes describe the TF op.
``verbose``   Optional verbosity level (default ``0``).
============= =====================================================

The function should return the name of the primary output tensor. Input
tensor names are obtained via ``op.inputs[i].name`` and attribute values via
``op.get_attr("attr_name")``.

Supported ops
=============

See :ref:`l-design-tensorflow-supported-ops` for the full list of built-in TF
op converters, generated automatically from the live registry.
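The registration and lookup semantics described above can be mimicked in a few
lines of plain Python. The sketch below is a hypothetical, self-contained
mirror of the registry in :mod:`yobx.tensorflow.register`, not the package's
actual code; it only reproduces the documented behaviour: string-or-tuple
registration, ``TypeError`` on any duplicate key, and ``None`` on a failed
lookup.

```python
from typing import Callable, Dict, Optional, Tuple, Union

# Hypothetical mirror of the module-level registry, for illustration only.
TF_OP_CONVERTERS: Dict[str, Callable] = {}


def register_tf_op_converter(op_types: Union[str, Tuple[str, ...]]):
    """Register a converter under one or several TF op-type strings."""
    names = (op_types,) if isinstance(op_types, str) else tuple(op_types)

    def wrapper(fn: Callable) -> Callable:
        for name in names:
            if name in TF_OP_CONVERTERS:
                # Each string is registered before the next one is checked,
                # so duplicates inside the same tuple also raise.
                raise TypeError(f"A converter for {name!r} is already registered.")
            TF_OP_CONVERTERS[name] = fn
        return fn

    return wrapper


def get_tf_op_converter(op_type: str) -> Optional[Callable]:
    """Return the registered converter for ``op_type``, or None."""
    return TF_OP_CONVERTERS.get(op_type)


@register_tf_op_converter(("MyOp", "MyOpV2"))
def convert_my_op(g, sts, outputs, op):
    return outputs[0]
```

After the decorator runs, ``get_tf_op_converter("MyOp")`` and
``get_tf_op_converter("MyOpV2")`` both return ``convert_my_op``, and a second
registration of either name raises ``TypeError``.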
Dynamic shapes
==============

By default :func:`to_onnx <yobx.tensorflow.to_onnx>` marks axis 0 of every
input as dynamic (unnamed batch dimension). To control which axes are
dynamic, pass ``dynamic_shapes`` — a tuple of one ``Dict[int, str]`` per
input where keys are axis indices and values are symbolic dimension names:

.. code-block:: python

    # axis 0 named "batch", axis 1 fixed
    onx = to_onnx(model, (X,), dynamic_shapes=({0: "batch"},))

Custom op converters
====================

The ``extra_converters`` parameter of
:func:`to_onnx <yobx.tensorflow.to_onnx>` accepts a mapping from TF op-type
string to converter function. Entries here take **priority** over the
built-in registry, making it easy to override or extend coverage without
modifying the package.

.. code-block:: python

    import numpy as np
    import tensorflow as tf

    from yobx.tensorflow import to_onnx

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, activation="relu", input_shape=(3,))
    ])
    X = np.random.rand(5, 3).astype(np.float32)

    def custom_relu(g, sts, outputs, op):
        """Replace Relu with a Clip(0, 1) to saturate outputs at 1."""
        return g.op.Clip(
            op.inputs[0].name,
            np.array(0.0, dtype=np.float32),
            np.array(1.0, dtype=np.float32),
            outputs=outputs[:1],
            name=op.name,
        )

    onx = to_onnx(model, (X,), extra_converters={"Relu": custom_relu})

Adding a new built-in converter
================================

To extend the built-in op coverage:

1. Create a new file under ``yobx/tensorflow/ops/``
   (e.g. ``yobx/tensorflow/ops/reduce.py``).
2. Implement a converter function following the signature above.
3. Decorate it with ``@register_tf_op_converter("Sum")``, using the TF
   op-type string (or a tuple of strings for multiple op types).
4. Import the new module inside the ``register()`` function in
   ``yobx/tensorflow/ops/__init__.py``.

.. code-block:: python

    # yobx/tensorflow/ops/reduce.py
    import numpy as np
    from onnx import TensorProto

    from ..register import register_tf_op_converter
    from ...xbuilder import GraphBuilder

    @register_tf_op_converter("Sum")
    def convert_reduce_sum(g: GraphBuilder, sts: dict, outputs: list, op) -> str:
        """TF ``Sum`` → ONNX ``ReduceSum``."""
        keepdims = int(op.get_attr("keep_dims"))
        # TF may pass a 0-D scalar for single-axis reductions; ONNX requires 1-D.
        axes_i64 = g.op.Cast(
            op.inputs[1].name, to=TensorProto.INT64, name=f"{op.name}_cast"
        )
        axes = g.op.Reshape(
            axes_i64, np.array([-1], dtype=np.int64), name=f"{op.name}_axes"
        )
        return g.op.ReduceSum(
            op.inputs[0].name, axes, keepdims=keepdims, outputs=outputs,
            name=op.name,
        )
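The scalar-axes quirk that the ``Cast``/``Reshape`` pair above works around
can be demonstrated with plain numpy. This is an illustrative sketch, and
``normalize_axes`` is a hypothetical helper rather than part of the package:
reshaping to ``[-1]`` maps both a 0-D scalar axis and an already 1-D axes
tensor to the 1-D ``int64`` form that ONNX ``ReduceSum`` expects.

```python
import numpy as np


def normalize_axes(axes) -> np.ndarray:
    """Mimic the Cast(INT64) + Reshape([-1]) pair emitted by the converter.

    TF can hand the converter either a 0-D scalar (single-axis reduction,
    e.g. tf.reduce_sum(x, axis=1)) or a 1-D tensor of axes; both come out
    as a 1-D int64 array.
    """
    return np.asarray(axes, dtype=np.int64).reshape(-1)


scalar_axes = np.array(1, dtype=np.int32)       # 0-D scalar from TF
vector_axes = np.array([0, 2], dtype=np.int32)  # already 1-D

print(normalize_axes(scalar_axes))  # [1]
print(normalize_axes(vector_axes))  # [0 2]
```

Without this normalisation, feeding the 0-D scalar straight into
``ReduceSum`` would violate the operator's requirement that ``axes`` be 1-D.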