Supported JAX Ops
When a JAX function is converted to ONNX via
to_onnx(), the JAX computation is first lowered to an
XlaCallModule TensorFlow op whose payload contains a StableHLO MLIR
module. The converter parses that module op-by-op and maps each
stablehlo.* operator to an ONNX node.
The tables below list every stablehlo op name (after stripping the
stablehlo. prefix) that is currently supported, together with the
corresponding ONNX op or sub-graph it is lowered to. Each ONNX op name
links to its entry in the ONNX Operators specification, and each
StableHLO op name links to the StableHLO specification.
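To see ahead of time which stablehlo ops a function will hand to the converter, you can scan its textual StableHLO module. The sketch below assumes text in the style printed by jax.jit(f).lower(x, y).as_text(). The sample is hand-written rather than captured output, and the regular expression is an illustrative shortcut, not the parser yobx uses. It reproduces the synthetic compare_<DIRECTION> names that the tables use for stablehlo.compare.

```python
import re

# Hand-written sample in the style of jax.jit(f).lower(x, y).as_text();
# illustrative only, not captured from a real run.
SAMPLE = """
module @jit_f {
  func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> (tensor<3xf32>) {
    %0 = stablehlo.sqrt %arg0 : tensor<3xf32>
    %1 = stablehlo.compare  GT, %0, %arg1,  FLOAT : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
    %2 = stablehlo.select %1, %0, %arg1 : tensor<3xi1>, tensor<3xf32>
    return %2 : tensor<3xf32>
  }
}
"""

# Captures the op name after "stablehlo." and, for compare, its direction.
_OP_RE = re.compile(r"\bstablehlo\.(\w+)(?:\s+(EQ|NE|GT|GE|LT|LE)\b)?")


def stablehlo_op_names(mlir_text):
    """Return op names in order of appearance, with the prefix stripped."""
    names = []
    for op, direction in _OP_RE.findall(mlir_text):
        if op == "compare" and direction:
            names.append(f"compare_{direction}")  # synthetic table key
        else:
            names.append(op)
    return names


def missing_ops(mlir_text, supported):
    """Sorted names used in the module that are absent from `supported`."""
    return sorted(set(stablehlo_op_names(mlir_text)) - set(supported))


print(stablehlo_op_names(SAMPLE))  # ['sqrt', 'compare_GT', 'select']
```

In practice, `supported` could be built from the keys of the tables below, for example set(_MAPPING_JAX_ONNX) | set(_COMPOSITE_JAX_OPS). Structural ops may be keyed differently, so check those by hand.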
Direct mappings
These stablehlo ops map to a single ONNX op with the same semantics:
<<<
from yobx.tensorflow.ops.jax_ops import _MAPPING_JAX_ONNX

_ONNX_BASE = "https://onnx.ai/onnx/operators/onnx__{}.html"
_STABLEHLO_BASE = "https://openxla.org/stablehlo/spec#{}"


def _hlo_anchor(jax_op):
    # compare_EQ/GT/GE/LT/LE/NE are synthetic names for stablehlo.compare
    if jax_op.startswith("compare_"):
        return "compare"
    return jax_op.replace("_", "-")


# Sort by StableHLO op name for a more readable table
rows = sorted(_MAPPING_JAX_ONNX.items())  # (jax_op, onnx_op)

print(".. list-table::")
print("    :header-rows: 1")
print("    :widths: 40 60")
print()
print("    * - StableHLO op (``stablehlo.<name>``)")
print("      - ONNX op")
for jax_op, onnx_op in rows:
    hlo_url = _STABLEHLO_BASE.format(_hlo_anchor(jax_op))
    onnx_url = _ONNX_BASE.format(onnx_op)
    print(f"    * - `{jax_op} <{hlo_url}>`_")
    print(f"      - `{onnx_op} <{onnx_url}>`_")
print()
>>>
| StableHLO op (`stablehlo.<name>`) | ONNX op |
|---|---|
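Because every entry here is a one-to-one rename, lowering a direct op amounts to a table lookup that keeps the operands in order. The plain-Python sketch below illustrates this idea. The table entries, the OnnxNode class and the helper names are all hypothetical stand-ins, not yobx code. The sketch also models how the comparison direction of stablehlo.compare selects one of the synthetic compare_* keys mentioned in the generator script's comment.

```python
from dataclasses import dataclass

# Hypothetical excerpt of a direct mapping table; the real
# _MAPPING_JAX_ONNX may hold different keys and values.
DIRECT = {
    "abs": "Abs",
    "sqrt": "Sqrt",
    "compare_EQ": "Equal",
    "compare_GT": "Greater",
    "compare_LT": "Less",
}


@dataclass
class OnnxNode:
    """Minimal stand-in for an ONNX NodeProto."""
    op_type: str
    inputs: list
    outputs: list


def table_key(hlo_op, comparison_direction=None):
    """stablehlo.compare is looked up under a synthetic compare_<DIR> key."""
    if hlo_op == "compare":
        if comparison_direction is None:
            raise ValueError("stablehlo.compare needs a comparison_direction")
        return f"compare_{comparison_direction}"
    return hlo_op


def lower_direct(hlo_op, inputs, output, comparison_direction=None):
    """Emit exactly one ONNX node whose inputs keep their original order."""
    op_type = DIRECT[table_key(hlo_op, comparison_direction)]
    return OnnxNode(op_type, list(inputs), [output])


node = lower_direct("compare", ["a", "b"], "c", comparison_direction="GT")
print(node)  # OnnxNode(op_type='Greater', inputs=['a', 'b'], outputs=['c'])
```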
Composite mappings
These stablehlo ops require more than one ONNX node and are implemented
as small sub-graphs by dedicated factory functions:
<<<
from yobx.tensorflow.ops.jax_ops import _COMPOSITE_JAX_OPS

_ONNX_BASE = "https://onnx.ai/onnx/operators/onnx__{}.html"
_STABLEHLO_BASE = "https://openxla.org/stablehlo/spec#{}"


def _hlo_anchor(jax_op):
    if jax_op.startswith("compare_"):
        return "compare"
    return jax_op.replace("_", "-")


def _onnx_link(op):
    return f"`{op} <{_ONNX_BASE.format(op)}>`_"


_descriptions = {
    "rsqrt": f"{_onnx_link('Reciprocal')} ( {_onnx_link('Sqrt')} (x) )",
    "log_plus_one": f"{_onnx_link('Log')} ( {_onnx_link('Add')} (x, 1) )",
    "exponential_minus_one": f"{_onnx_link('Sub')} ( {_onnx_link('Exp')} (x), 1 )",
    "compare_NE": f"{_onnx_link('Not')} ( {_onnx_link('Equal')} (a, b) )",
}

print(".. list-table::")
print("    :header-rows: 1")
print("    :widths: 40 60")
print()
print("    * - StableHLO op (``stablehlo.<name>``)")
print("      - ONNX equivalent")
for jax_op in sorted(_COMPOSITE_JAX_OPS):
    hlo_url = _STABLEHLO_BASE.format(_hlo_anchor(jax_op))
    desc = _descriptions.get(jax_op, "*(see source)*")
    print(f"    * - `{jax_op} <{hlo_url}>`_")
    print(f"      - {desc}")
print()
>>>
| StableHLO op (`stablehlo.<name>`) | ONNX equivalent |
|---|---|
| rsqrt | Reciprocal ( Sqrt (x) ) |
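One way to sanity-check a composite lowering is to evaluate the decomposition alongside a reference implementation. The sketch below does this in float64 with Python's math module, which is an illustrative assumption: the converter emits ONNX nodes, not Python. It mirrors the four decompositions in _descriptions. It also highlights a known caveat. Log(Add(x, 1)) and Sub(Exp(x), 1) lose relative precision for very small |x|, because 1 + x is rounded before the outer op runs, whereas dedicated log1p and expm1 functions stay accurate.

```python
import math


# Plain-Python mirrors of the ONNX decompositions in _descriptions.
def rsqrt(x):
    return 1.0 / math.sqrt(x)  # Reciprocal(Sqrt(x))


def log_plus_one(x):
    return math.log(x + 1.0)  # Log(Add(x, 1))


def exponential_minus_one(x):
    return math.exp(x) - 1.0  # Sub(Exp(x), 1)


def compare_ne(a, b):
    return not (a == b)  # Not(Equal(a, b)); True for NaN vs NaN, as IEEE "!=" is


def rel_err(approx, exact):
    return abs(approx - exact) / abs(exact)


for x in (0.5, 1e-10):
    print(
        f"x={x:g}",
        f"log1p rel err={rel_err(log_plus_one(x), math.log1p(x)):.1e}",
        f"expm1 rel err={rel_err(exponential_minus_one(x), math.expm1(x)):.1e}",
    )
```

In float32, which ONNX models commonly use, the rounding of 1 + x is coarser, so the effect appears at larger |x|.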
Structural ops
The following stablehlo ops are handled directly by the
XlaCallModule converter in yobx.tensorflow.ops.xla_call_module
and do not go through get_jax_cvt():
<<<
from yobx.tensorflow.ops.xla_call_module import _STRUCTURAL_OPS

_ONNX_BASE = "https://onnx.ai/onnx/operators/onnx__{}.html"
_STABLEHLO_BASE = "https://openxla.org/stablehlo/spec#{}"


def _hlo_anchor(op):
    # Use generic anchors for ops that share a spec section
    _overrides = {
        "reduce_max": "reduce",
        "reduce_sum": "reduce",
    }
    return _overrides.get(op, op.replace("_", "-"))


print(".. list-table::")
print("    :header-rows: 1")
print("    :widths: 40 60")
print()
print("    * - StableHLO op")
print("      - ONNX equivalent")
for hlo_op, (onnx_op, desc, _handler) in sorted(_STRUCTURAL_OPS.items()):
    hlo_url = _STABLEHLO_BASE.format(_hlo_anchor(hlo_op))
    hlo_cell = f"`{hlo_op} <{hlo_url}>`_"
    if onnx_op is not None:
        onnx_url = _ONNX_BASE.format(onnx_op)
        onnx_cell = f"`{onnx_op} <{onnx_url}>`_ — {desc}"
    else:
        onnx_cell = desc
    print(f"    * - {hlo_cell}")
    print(f"      - {onnx_cell}")
print()
>>>
| StableHLO op | ONNX equivalent |
|---|---|
| | identity pass-through (ONNX broadcasting is implicit) |
| | inlined private function (no ONNX op emitted) |
| | ONNX initializer (weight tensor) |
| | Cast — type cast |
| | MatMul — matrix multiply |
| | identity pass-through (ONNX broadcasting is implicit) |
| reduce_max | ReduceMax — reduce along axes keeping dims |
| reduce_sum | ReduceSum — reduce along axes keeping dims |
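The split between this table and the previous two can be thought of as a lookup order. The XlaCallModule converter intercepts structural ops first, and only the remaining names are resolved through the direct and composite tables. The sketch below is a simplified, hypothetical model of that routing. The route function, the three small dictionaries and their contents (including which stablehlo names map to Cast and MatMul) are illustrative stand-ins, not the yobx tables or the real signature of get_jax_cvt().

```python
# Illustrative stand-ins; the real tables live in yobx and hold richer values.
STRUCTURAL = {"convert": "Cast", "dot_general": "MatMul"}  # assumed op names
DIRECT = {"sqrt": "Sqrt", "abs": "Abs"}
COMPOSITE = {"rsqrt": ("Sqrt", "Reciprocal")}


def route(hlo_op):
    """Return which mechanism would handle stablehlo.<hlo_op> in this model."""
    if hlo_op in STRUCTURAL:
        # Handled inside the XlaCallModule converter itself.
        return "structural"
    # Everything else goes through the per-op converter lookup.
    if hlo_op in DIRECT:
        return "direct"
    if hlo_op in COMPOSITE:
        return "composite"
    raise NotImplementedError(f"stablehlo.{hlo_op} is not supported")


for op in ["convert", "sqrt", "rsqrt"]:
    print(op, "->", route(op))
```

Because the direct and composite tables are expected to have disjoint keys, the order in which those two are checked should not matter. What matters in this model is only that structural ops are checked first.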
Adding a new JAX op mapping
To add support for an additional stablehlo unary op:
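If the op maps 1-to-1 to an ONNX op, add an entry to
_MAPPING_JAX_ONNX in yobx.tensorflow.ops.jax_ops:

```python
from yobx.tensorflow.ops.jax_ops import _MAPPING_JAX_ONNX

# Placeholder: ONNX has no Cbrt op, so cbrt itself needs the composite
# route below; "SomeOnnxOp" stands for a real ONNX op type.
_MAPPING_JAX_ONNX["cbrt"] = "SomeOnnxOp"  # only if a direct match exists
```

If the op requires multiple ONNX nodes, add a _make_<name> factory
function and register it in _COMPOSITE_JAX_OPS. The factory receives g
and returns a callable that emits the ONNX nodes through g.op:

```python
import numpy as np

from yobx.tensorflow.ops.jax_ops import _COMPOSITE_JAX_OPS


def _make_cbrt(g):
    def _cbrt(*args, **kwargs):
        name = kwargs.pop("name", "cbrt")
        outputs = kwargs.pop("outputs", None)
        (x,) = args
        exp = np.array(1.0 / 3.0, dtype=np.float32)
        kw = {"name": name}
        if outputs is not None:
            kw["outputs"] = outputs
        # Pow(x, 1/3) alone returns NaN for negative x, so compute
        # cbrt(x) = Sign(x) * Pow(Abs(x), 1/3) instead.
        sign = g.op.Sign(x, name=name)
        magnitude = g.op.Pow(g.op.Abs(x, name=name), exp, name=name)
        return g.op.Mul(sign, magnitude, **kw)

    return _cbrt


_COMPOSITE_JAX_OPS["cbrt"] = _make_cbrt
```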
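Before registering a composite such as cbrt, it helps to check the decomposition numerically. NumPy's power returns NaN for a negative base with a fractional exponent, as C's pow does. ONNX runtimes built on pow typically behave the same way, but that is an assumption to confirm on your runtime. The sketch below compares a bare Pow(x, 1/3) and Sign(x) * Pow(Abs(x), 1/3) against np.cbrt on float32 inputs.

```python
import numpy as np

x = np.array([-8.0, -1.0, 0.0, 1.0, 27.0], dtype=np.float32)
third = np.float32(1.0 / 3.0)

# Bare Pow(x, 1/3): NaN wherever x < 0.
with np.errstate(invalid="ignore"):
    naive = np.power(x, third)

# Sign(x) * Pow(Abs(x), 1/3): defined for every real x, including 0.
signed = np.sign(x) * np.power(np.abs(x), third)

reference = np.cbrt(x)
print("naive :", naive)
print("signed:", signed)
print("match :", np.allclose(signed, reference, atol=1e-5))
```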