yobx.tensorflow.ops.jax_ops

StableHLO → ONNX op mappings for XlaCallModule conversion.

When a JAX function is lowered through jax2tf, TensorFlow wraps the computation in an XlaCallModule op whose payload is a StableHLO MLIR module. get_jax_cvt() is the single look-up point used by the XlaCallModule converter to find the appropriate ONNX emission callable for each stablehlo.* op encountered in the parsed MLIR.
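Each op in a StableHLO module is spelled stablehlo.<name> in the MLIR text, and get_jax_cvt() receives that name with the prefix removed. The minimal sketch below uses a regular expression as a rough stand-in for the real MLIR parser. It collects op names in order of appearance and strips the prefix. The helper stablehlo_op_names and the sample module are hypothetical and are not part of yobx.

```python
import re
from typing import List

# Matches both pretty (`stablehlo.sine %x`) and generic (`"stablehlo.sine"(%x)`)
# spellings and captures the op name without the `stablehlo.` prefix.
_STABLEHLO_OP = re.compile(r"\bstablehlo\.([a-z_]+)")


def stablehlo_op_names(mlir_text: str) -> List[str]:
    """Hypothetical helper: stablehlo.* op names in order, prefix stripped."""
    return [m.group(1) for m in _STABLEHLO_OP.finditer(mlir_text)]


mlir = """
func.func public @main(%arg0: tensor<3xf32>) -> tensor<3xf32> {
  %0 = stablehlo.sine %arg0 : tensor<3xf32>
  %1 = stablehlo.sqrt %0 : tensor<3xf32>
  return %1 : tensor<3xf32>
}
"""
print(stablehlo_op_names(mlir))  # ['sine', 'sqrt']
```

Each resulting name, such as "sine", is the form the jax_type parameter expects.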

Direct mappings (_MAPPING_JAX_ONNX)

StableHLO unary and binary ops that map 1-to-1 to a single ONNX op.
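A direct mapping reduces to a name-to-op_type table plus one node emission per call. The sketch below is a minimal illustration under two assumptions. First, the table excerpt is illustrative and is not the real _MAPPING_JAX_ONNX. Second, ToyBuilder and its make_node method are a hypothetical stand-in for GraphBuilderExtendedProtocol, whose actual methods are not shown here.

```python
from typing import Callable, Dict, List, Tuple

# Illustrative excerpt: StableHLO op name (prefix stripped) -> ONNX op_type.
# Not the actual contents of _MAPPING_JAX_ONNX.
DIRECT_MAPPING: Dict[str, str] = {
    "sine": "Sin",
    "cosine": "Cos",
    "sqrt": "Sqrt",
    "exponential": "Exp",
    "add": "Add",
    "multiply": "Mul",
}


class ToyBuilder:
    """Hypothetical stand-in for GraphBuilderExtendedProtocol: records nodes."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, List[str], str]] = []

    def make_node(self, op_type: str, inputs: List[str]) -> str:
        # Derive a unique output name from the op type and node count.
        out = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, list(inputs), out))
        return out


def direct_cvt(g: ToyBuilder, jax_type: str) -> Callable[..., str]:
    """Return a callable that emits exactly one ONNX node for jax_type."""
    op_type = DIRECT_MAPPING[jax_type]
    # Unary ops receive one input name; binary ops receive two.
    return lambda *inputs: g.make_node(op_type, list(inputs))


g = ToyBuilder()
s = direct_cvt(g, "sine")("x")
direct_cvt(g, "add")(s, "y")
print(g.nodes)  # [('Sin', ['x'], 'sin_0'), ('Add', ['sin_0', 'y'], 'add_1')]
```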

Composite mappings (_COMPOSITE_JAX_OPS)

StableHLO ops that require more than one ONNX node. Each is implemented as a small factory function that closes over the GraphBuilderExtendedProtocol instance.
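A composite entry can be sketched as a factory that takes the builder and returns a callable that closes over it. stablehlo.rsqrt is a plausible candidate because ONNX has no Rsqrt operator, so 1 / sqrt(x) takes a Sqrt node followed by a Reciprocal node. Whether _COMPOSITE_JAX_OPS actually handles rsqrt this way is an assumption. ToyBuilder, make_node and make_rsqrt are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, List, Tuple


class ToyBuilder:
    """Hypothetical stand-in for GraphBuilderExtendedProtocol: records nodes."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, List[str], str]] = []

    def make_node(self, op_type: str, inputs: List[str]) -> str:
        out = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, list(inputs), out))
        return out


def make_rsqrt(g: ToyBuilder) -> Callable[[str], str]:
    """Factory: the returned callable closes over g and emits two nodes."""

    def rsqrt(x: str) -> str:
        # ONNX has no Rsqrt op, so compute 1 / sqrt(x) in two steps.
        s = g.make_node("Sqrt", [x])
        return g.make_node("Reciprocal", [s])

    return rsqrt


# Illustrative composite table: op name -> factory(builder) -> emitter.
COMPOSITE: Dict[str, Callable[[ToyBuilder], Callable[[str], str]]] = {
    "rsqrt": make_rsqrt,
}

g = ToyBuilder()
out = COMPOSITE["rsqrt"](g)("x")
print(g.nodes)  # [('Sqrt', ['x'], 'sqrt_0'), ('Reciprocal', ['sqrt_0'], 'reciprocal_1')]
print(out)      # reciprocal_1
```

Binding the builder once in a factory lets the table keep uniform entries, while every emitted node still lands in the builder that is active for the current conversion.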

yobx.tensorflow.ops.jax_ops.get_jax_cvt(assembly_code: str, g: GraphBuilderExtendedProtocol, jax_type: str)

Return an ONNX-emission callable for the StableHLO op named by jax_type.

Parameters:
  • assembly_code – full MLIR text (used only in the error message).

  • g – the active GraphBuilderExtendedProtocol.

  • jax_type – StableHLO op name with the stablehlo. prefix already stripped (e.g. "sine", "sqrt").

Raises:

RuntimeError – if jax_type has no registered converter.
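The look-up contract described above can be sketched as follows. The function returns a one-node emitter for a direct mapping, asks a factory for an emitter for a composite op, and otherwise raises RuntimeError with the MLIR text attached. This is a minimal sketch, not the library's implementation. The lookup order (direct table first) is an assumption, and get_jax_cvt_sketch, ToyBuilder, MAPPING and COMPOSITE are hypothetical stand-ins for the real function, GraphBuilderExtendedProtocol, _MAPPING_JAX_ONNX and _COMPOSITE_JAX_OPS.

```python
from typing import Callable, List, Tuple


class ToyBuilder:
    """Hypothetical stand-in for GraphBuilderExtendedProtocol: records nodes."""

    def __init__(self) -> None:
        self.nodes: List[Tuple[str, Tuple[str, ...], str]] = []

    def make_node(self, op_type: str, inputs) -> str:
        out = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, tuple(inputs), out))
        return out


# Illustrative tables, not the real module-level ones.
MAPPING = {"sine": "Sin", "sqrt": "Sqrt"}


def _make_rsqrt(g: ToyBuilder) -> Callable[[str], str]:
    # Sqrt is emitted first, then Reciprocal consumes its output.
    return lambda x: g.make_node("Reciprocal", [g.make_node("Sqrt", [x])])


COMPOSITE = {"rsqrt": _make_rsqrt}


def get_jax_cvt_sketch(assembly_code: str, g: ToyBuilder, jax_type: str) -> Callable[..., str]:
    # Assumed order: try the direct 1-to-1 table, then the composite factories.
    if jax_type in MAPPING:
        op_type = MAPPING[jax_type]
        return lambda *inputs: g.make_node(op_type, inputs)
    if jax_type in COMPOSITE:
        return COMPOSITE[jax_type](g)
    # assembly_code only serves to make the error message actionable.
    raise RuntimeError(f"No converter registered for stablehlo.{jax_type}.\n{assembly_code}")


g = ToyBuilder()
y = get_jax_cvt_sketch("", g, "sine")("x")
get_jax_cvt_sketch("", g, "rsqrt")(y)
print([n[0] for n in g.nodes])  # ['Sin', 'Sqrt', 'Reciprocal']
```

Including the full MLIR text in the error lets a user see which module contained the unsupported op without re-running the export.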