yobx.tensorflow.tensorflow_helper

yobx.tensorflow.tensorflow_helper.get_output_names(model) → Sequence[str]

Returns output names for a Keras model or layer.

Note

This proof-of-concept (POC) implementation always returns a single output named "output"; multi-output models are not yet supported.

yobx.tensorflow.tensorflow_helper.jax_to_concrete_function(jax_fn: Any, args: Tuple[Any, ...], input_names: Sequence[str] | None = None, dynamic_shapes: Tuple[Dict[int, str], ...] | None = None)

Converts a JAX function into a tensorflow.ConcreteFunction.

Uses jax.experimental.jax2tf.convert() to wrap the JAX function as a TensorFlow function, then traces it with tf.function to produce a concrete computation graph. The resulting ConcreteFunction can be passed directly to yobx.tensorflow.to_onnx() for ONNX export.

Parameters:
  • jax_fn – a callable JAX function (or a Flax or Equinox model wrapped in a plain Python function) whose outputs are JAX arrays.

  • args – dummy inputs as numpy.ndarray objects; used to infer the dtype and static shape of each input tensor.

  • input_names – optional list of names for the ONNX input tensors. When None, inputs are named "X" (single input) or "X0", "X1", … (multiple inputs).

  • dynamic_shapes – optional per-input axis-to-dim-name mappings. Example: ({0: "batch"},) marks axis 0 of the first input as a dynamic (variable-length) dimension. When None, axis 0 of every input is made dynamic by default.

Returns:

a tensorflow.ConcreteFunction ready for ONNX export.

Example:

```python
import numpy as np
import jax.numpy as jnp
from yobx.tensorflow import to_onnx
from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function

def jax_fn(x):
    return jnp.sin(x)

x = np.random.rand(4, 3).astype(np.float32)
cf = jax_to_concrete_function(jax_fn, (x,), dynamic_shapes=({0: "batch"},))
onx = to_onnx(cf, (x,), dynamic_shapes=({0: "batch"},))
```

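
The default naming and dynamic-axis rules described under Parameters can be sketched in plain Python. This is an illustrative sketch, not the yobx code. The helper names are hypothetical. The dimension name "batch" used for the default dynamic axis is an assumption, because the documentation does not say which name is used. The "(batch, 3)" strings only imitate the shape-polymorphism spec syntax accepted by the `polymorphic_shapes` argument of `jax.experimental.jax2tf.convert()`, and whether yobx builds such strings internally is also an assumption.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def default_input_names(n_inputs: int) -> List[str]:
    # As documented: "X" for a single input, "X0", "X1", ... for several.
    if n_inputs == 1:
        return ["X"]
    return [f"X{i}" for i in range(n_inputs)]


def default_dynamic_shapes(n_inputs: int) -> Tuple[Dict[int, str], ...]:
    # As documented, axis 0 of every input is dynamic by default.
    # Assumption: the dynamic dimension is named "batch".
    return tuple({0: "batch"} for _ in range(n_inputs))


def polymorphic_shape_specs(
    shapes: Sequence[Shape],
    dynamic_shapes: Optional[Tuple[Dict[int, str], ...]] = None,
) -> List[str]:
    # Hypothetical translation of axis-to-name mappings into jax2tf-style
    # spec strings: dynamic axes get their name, static axes keep their size.
    if dynamic_shapes is None:
        dynamic_shapes = default_dynamic_shapes(len(shapes))
    specs = []
    for shape, dyn in zip(shapes, dynamic_shapes):
        dims = [dyn.get(axis, str(size)) for axis, size in enumerate(shape)]
        # Mirror Python tuple syntax: a 1-D shape keeps a trailing comma.
        inner = dims[0] + "," if len(dims) == 1 else ", ".join(dims)
        specs.append(f"({inner})")
    return specs


print(default_input_names(1))                  # ['X']
print(default_input_names(3))                  # ['X0', 'X1', 'X2']
print(polymorphic_shape_specs([(4, 3)]))       # ['(batch, 3)']
print(polymorphic_shape_specs([(4, 3), (4, 2)], ({0: "n"}, {})))  # ['(n, 3)', '(4, 2)']
```

Passing an empty mapping (`{}`) for an input keeps all of its axes static in this sketch.
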
yobx.tensorflow.tensorflow_helper.tf_dtype_to_np_dtype(tf_dtype)

Converts a TensorFlow dtype to a NumPy dtype.
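
The following is a minimal sketch of such a conversion. It relies on the `as_numpy_dtype` property that TensorFlow `tf.dtypes.DType` objects expose, and falls back to `numpy.dtype(tf_dtype.name)` for objects that lack that property. The function name `tf_dtype_to_np_dtype_sketch` is hypothetical, and the actual yobx implementation may differ. To keep the example runnable without TensorFlow, it uses `SimpleNamespace` stand-ins instead of real `tf.float32` / `tf.int64` objects.

```python
from types import SimpleNamespace
from typing import Any

import numpy as np


def tf_dtype_to_np_dtype_sketch(tf_dtype: Any) -> np.dtype:
    # TensorFlow DType objects expose ``as_numpy_dtype`` (e.g. np.float32).
    np_type = getattr(tf_dtype, "as_numpy_dtype", None)
    if np_type is not None:
        return np.dtype(np_type)
    # Fallback (assumption): reuse the dtype's name, e.g. "float32" or "int64".
    return np.dtype(tf_dtype.name)


# Stand-ins for tf.float32 and tf.int64 so the sketch runs without TensorFlow.
fake_f32 = SimpleNamespace(name="float32", as_numpy_dtype=np.float32)
fake_i64 = SimpleNamespace(name="int64")

print(tf_dtype_to_np_dtype_sketch(fake_f32))  # float32
print(tf_dtype_to_np_dtype_sketch(fake_i64))  # int64
```

With TensorFlow installed, passing a real dtype such as `tf.float32` takes the `as_numpy_dtype` branch.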