yobx.tensorflow.tensorflow_helper
- yobx.tensorflow.tensorflow_helper.get_output_names(model) → Sequence[str]
Returns output names for a Keras model or layer.
Note
This proof-of-concept implementation always returns a single output name,
"output"; multi-output models are not yet supported.
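Because the current implementation always yields exactly one name, code that consumes the result can check the number of outputs explicitly instead of assuming the two line up. The following is a minimal sketch that assumes only the behavior stated in the note. `poc_get_output_names` is a hypothetical stand-in, not the yobx function, and `name_outputs` is an illustrative helper.

```python
from typing import Any, Dict, Sequence


def poc_get_output_names(model: Any) -> Sequence[str]:
    """Hypothetical stand-in mirroring the documented proof-of-concept behavior."""
    # This stand-in ignores the model and always reports one output named "output".
    return ["output"]


def name_outputs(model: Any, outputs: Sequence[Any]) -> Dict[str, Any]:
    """Pair each model output with its name, failing loudly on multi-output models."""
    names = poc_get_output_names(model)
    if len(outputs) != len(names):
        raise NotImplementedError(
            f"got {len(outputs)} outputs but only {len(names)} name(s); "
            "multi-output models are not supported yet"
        )
    return dict(zip(names, outputs))


print(name_outputs(model=None, outputs=[[1.0, 2.0]]))  # {'output': [1.0, 2.0]}
```

Raising early keeps a mismatch from silently dropping extra outputs, which `zip` would otherwise do.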
- yobx.tensorflow.tensorflow_helper.jax_to_concrete_function(jax_fn: Any, args: Tuple[Any, ...], input_names: Sequence[str] | None = None, dynamic_shapes: Tuple[Dict[int, str], ...] | None = None)
Converts a JAX function into a tensorflow.ConcreteFunction.
Uses jax.experimental.jax2tf.convert() to wrap the JAX function as a TensorFlow function, then traces it with tf.function to produce a concrete computation graph. The resulting ConcreteFunction can be passed directly to yobx.tensorflow.to_onnx() for ONNX export.
- Parameters:
  - jax_fn – a callable JAX function (or a flax/equinox model wrapped in a plain Python function) whose outputs are JAX arrays.
  - args – dummy inputs as numpy.ndarray objects; used to infer the dtype and static shape of each input tensor.
  - input_names – optional list of names for the ONNX input tensors. When None, inputs are named "X" (single input) or "X0", "X1", … (multiple inputs).
  - dynamic_shapes – optional per-input mappings from axis index to dimension name. For example, ({0: "batch"},) marks axis 0 of the first input as a dynamic (variable-length) dimension. When None, axis 0 of every input is made dynamic by default.
- Returns:
  a tensorflow.ConcreteFunction ready for ONNX export.
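The defaulting rules for input_names and dynamic_shapes can be sketched in plain Python. This sketch assumes only the documented behavior and is not the yobx implementation. `resolve_input_specs` is a hypothetical helper. The documentation does not say which dimension name the default dynamic axis 0 receives, so the sketch only records *which* axes are dynamic and marks each one as None.

```python
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np


def resolve_input_specs(
    args: Sequence[np.ndarray],
    input_names: Optional[Sequence[str]] = None,
    dynamic_shapes: Optional[Tuple[Dict[int, str], ...]] = None,
) -> List[Tuple[str, np.dtype, Tuple[Optional[int], ...]]]:
    """Hypothetical helper: (name, dtype, shape) per input, None marking a dynamic axis."""
    # Documented default names: "X" for a single input, "X0", "X1", ... otherwise.
    if input_names is None:
        input_names = ["X"] if len(args) == 1 else [f"X{i}" for i in range(len(args))]
    # Documented default: axis 0 of every input is dynamic.
    if dynamic_shapes is None:
        dynamic_axes = [{0} for _ in args]
    else:
        dynamic_axes = [set(d) for d in dynamic_shapes]
    if len(input_names) != len(args) or len(dynamic_axes) != len(args):
        raise ValueError("expected one name and one dynamic-shape entry per input")

    specs = []
    for name, arr, axes in zip(input_names, args, dynamic_axes):
        # Static sizes come from the dummy input; dynamic axes become None.
        shape = tuple(None if i in axes else s for i, s in enumerate(arr.shape))
        specs.append((name, arr.dtype, shape))
    return specs


x = np.zeros((4, 3), dtype=np.float32)
print(resolve_input_specs((x,)))  # [('X', dtype('float32'), (None, 3))]
```

Using None for a variable dimension matches how tf.TensorSpec describes an unknown dimension, for example tf.TensorSpec(shape=[None, 3], dtype=tf.float32). The dimension names themselves are not kept in this sketch.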
Example:
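import numpy as np
import jax.numpy as jnp
from yobx.tensorflow import to_onnx
from yobx.tensorflow.tensorflow_helper import jax_to_concrete_function

def jax_fn(x):
    return jnp.sin(x)

# Dummy input: fixes the dtype (float32) and the static size of axis 1.
x = np.random.rand(4, 3).astype(np.float32)
# Trace the JAX function into a ConcreteFunction with a dynamic "batch" axis.
cf = jax_to_concrete_function(jax_fn, (x,), dynamic_shapes=({0: "batch"},))
# Export the concrete function to ONNX, keeping axis 0 dynamic.
onx = to_onnx(cf, (x,), dynamic_shapes=({0: "batch"},))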