yobx.tensorflow.ops.jax_ops
StableHLO → ONNX op mappings for XlaCallModule conversion.
When a JAX function is lowered through jax2tf, TensorFlow wraps the
computation in an XlaCallModule op whose payload is a StableHLO MLIR
module. get_jax_cvt() is the single lookup point that the XlaCallModule
converter uses to find the ONNX emission callable for each stablehlo.* op
it encounters in the parsed MLIR.
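Each stablehlo.* op is resolved by its bare name. The following minimal sketch shows how those names could be collected from the module before being passed to get_jax_cvt(). It assumes the module is available as MLIR text. The regex and the helper stablehlo_op_names are hypothetical and are not part of yobx.

```python
import re

# Hypothetical helper, not part of yobx. Matches the op name in both the
# pretty form ("stablehlo.sine %x") and the generic form ('"stablehlo.sine"(%x)')
# and captures only the part after the "stablehlo." prefix.
_STABLEHLO_OP = re.compile(r"\bstablehlo\.([a-z0-9_]+)")


def stablehlo_op_names(mlir_text: str) -> list[str]:
    """Return StableHLO op names in textual order, prefix stripped."""
    return _STABLEHLO_OP.findall(mlir_text)


sample = """
func.func @main(%arg0: tensor<3xf32>) -> tensor<3xf32> {
  %0 = stablehlo.sine %arg0 : tensor<3xf32>
  %1 = stablehlo.sqrt %0 : tensor<3xf32>
  return %1 : tensor<3xf32>
}
"""

print(stablehlo_op_names(sample))  # ['sine', 'sqrt']
```

Including digits in the character class keeps names such as atan2 intact. A text scan like this is only an approximation of walking the parsed MLIR.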
Direct mappings (_MAPPING_JAX_ONNX)
StableHLO unary and binary ops that map 1-to-1 to a single ONNX op.
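A direct mapping is essentially a table from the stripped StableHLO name to one ONNX op type. The sketch below uses a hypothetical RecordingGraph stand-in with a make_node(op_type, inputs) method because the actual GraphBuilderExtendedProtocol API is not shown here. The table entries are an illustrative subset, not the real contents of _MAPPING_JAX_ONNX.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical subset of a direct mapping: StableHLO name (prefix stripped)
# -> ONNX op type. The real _MAPPING_JAX_ONNX may differ.
DIRECT_MAPPING: Dict[str, str] = {
    "sine": "Sin",
    "cosine": "Cos",
    "sqrt": "Sqrt",
    "exponential": "Exp",
    "log": "Log",
    "add": "Add",
    "subtract": "Sub",
    "multiply": "Mul",
    "divide": "Div",
}


class RecordingGraph:
    """Hypothetical stand-in for a graph builder: it only records nodes."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, List[str], str]] = []

    def make_node(self, op_type: str, inputs: List[str]) -> str:
        # Output names are derived from the op type and node index.
        output = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, list(inputs), output))
        return output


def direct_emitter(g: RecordingGraph, onnx_op: str) -> Callable[..., str]:
    """Build a callable that emits exactly one ONNX node for its inputs."""
    def emit(*inputs: str) -> str:
        return g.make_node(onnx_op, list(inputs))
    return emit


g = RecordingGraph()
y = direct_emitter(g, DIRECT_MAPPING["sine"])("x")     # unary
z = direct_emitter(g, DIRECT_MAPPING["add"])(y, "x")   # binary
print(g.nodes)
# [('Sin', ['x'], 'sin_0'), ('Add', ['sin_0', 'x'], 'add_1')]
```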
Composite mappings (_COMPOSITE_JAX_OPS)
StableHLO ops that require more than one ONNX node. Each is implemented as
a small factory function that closes over the GraphBuilderExtendedProtocol
instance.
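Consider StableHLO's rsqrt as an example of a composite op. ONNX has no Rsqrt operator, so the op can be emitted as Sqrt followed by Reciprocal. The following minimal sketch shows the factory pattern using a hypothetical RecordingGraph stand-in. The names make_rsqrt and COMPOSITE_OPS are illustrative and may not match the entries in _COMPOSITE_JAX_OPS.

```python
from typing import Callable, Dict, List, Tuple


class RecordingGraph:
    """Hypothetical stand-in for a graph builder: it only records nodes."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, List[str], str]] = []

    def make_node(self, op_type: str, inputs: List[str]) -> str:
        output = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, list(inputs), output))
        return output


def make_rsqrt(g: RecordingGraph) -> Callable[[str], str]:
    """Factory: the returned closure captures g and emits two ONNX nodes."""
    def rsqrt(x: str) -> str:
        root = g.make_node("Sqrt", [x])             # sqrt(x)
        return g.make_node("Reciprocal", [root])   # 1 / sqrt(x)
    return rsqrt


# Hypothetical composite table: op name -> factory taking the graph builder.
COMPOSITE_OPS: Dict[str, Callable[[RecordingGraph], Callable[[str], str]]] = {
    "rsqrt": make_rsqrt,
}

g = RecordingGraph()
out = COMPOSITE_OPS["rsqrt"](g)("x")
print(g.nodes)
# [('Sqrt', ['x'], 'sqrt_0'), ('Reciprocal', ['sqrt_0'], 'reciprocal_1')]
```

Because the factory receives the builder once, the returned callable has the same shape as a direct-mapping emitter. The caller therefore does not need to know how many nodes a given op expands to.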
- yobx.tensorflow.ops.jax_ops.get_jax_cvt(assembly_code: str, g: GraphBuilderExtendedProtocol, jax_type: str)
Return the ONNX-emission callable registered for the StableHLO op jax_type.
- Parameters:
assembly_code – full MLIR text (used only in the error message).
g – the active GraphBuilderExtendedProtocol.
jax_type – StableHLO op name with the stablehlo. prefix already stripped
(e.g. "sine", "sqrt").
- Raises:
RuntimeError – if jax_type has no registered converter.
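The documented contract of get_jax_cvt() is to look up jax_type, return a callable, and raise RuntimeError otherwise. It can be sketched as follows. Several details are assumptions for illustration: the order of the two lookups, the wording of the error message, and every name in the block (lookup_cvt, RecordingGraph, _DIRECT, _COMPOSITE). This is not the yobx implementation.

```python
from typing import Callable, Dict, List, Tuple


class RecordingGraph:
    """Hypothetical stand-in for the graph builder: it only records nodes."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, List[str], str]] = []

    def make_node(self, op_type: str, inputs: List[str]) -> str:
        output = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, list(inputs), output))
        return output


# Hypothetical tables standing in for _MAPPING_JAX_ONNX and _COMPOSITE_JAX_OPS.
_DIRECT: Dict[str, str] = {"sine": "Sin", "sqrt": "Sqrt"}


def _make_rsqrt(g: RecordingGraph) -> Callable[[str], str]:
    # Composite: 1 / sqrt(x) emitted as Sqrt then Reciprocal.
    def rsqrt(x: str) -> str:
        return g.make_node("Reciprocal", [g.make_node("Sqrt", [x])])
    return rsqrt


_COMPOSITE: Dict[str, Callable[[RecordingGraph], Callable[..., str]]] = {
    "rsqrt": _make_rsqrt,
}


def lookup_cvt(assembly_code: str, g: RecordingGraph, jax_type: str) -> Callable[..., str]:
    """Resolve jax_type (prefix already stripped) to an emission callable."""
    if jax_type in _DIRECT:
        onnx_op = _DIRECT[jax_type]
        return lambda *inputs: g.make_node(onnx_op, list(inputs))
    if jax_type in _COMPOSITE:
        return _COMPOSITE[jax_type](g)
    # assembly_code only serves to make the failure easier to diagnose.
    raise RuntimeError(
        f"No ONNX converter registered for 'stablehlo.{jax_type}'.\n{assembly_code}"
    )


g = RecordingGraph()
y = lookup_cvt("", g, "sine")("x")
r = lookup_cvt("", g, "rsqrt")(y)
print(g.nodes)
# [('Sin', ['x'], 'sin_0'), ('Sqrt', ['sin_0'], 'sqrt_1'), ('Reciprocal', ['sqrt_1'], 'reciprocal_2')]

try:
    lookup_cvt("%0 = stablehlo.cbrt %arg0 : tensor<3xf32>", g, "cbrt")
except RuntimeError as exc:
    message = str(exc)
```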