Supported JAX Ops

When a JAX function is converted to ONNX via to_onnx(), the JAX computation is first lowered to an XlaCallModule TensorFlow op whose payload is a StableHLO MLIR module. The converter parses that module op by op and maps each stablehlo.* operator to an ONNX node.

The tables below list every stablehlo op name (after stripping the stablehlo. prefix) that is currently supported, together with the corresponding ONNX op or sub-graph it is lowered to.
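The lowering step can be observed directly: for a small function, JAX can print the StableHLO module that the converter later receives. A sketch assuming a standard JAX installation, where jax.jit(f).lower(...).as_text() returns the MLIR text:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) + 1.0

# The MLIR text of the lowered module; this is what the XlaCallModule
# payload contains and what the converter parses op by op.
stablehlo_text = jax.jit(f).lower(jnp.ones((3,), jnp.float32)).as_text()
print(stablehlo_text)
```

Each stablehlo.tanh or stablehlo.add line in that output corresponds to one row of the tables below.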

Direct mappings

These stablehlo ops map to a single ONNX op with the same semantics:

<<<

from yobx.tensorflow.ops.jax_ops import _MAPPING_JAX_ONNX

# Sort by StableHLO op name so the table is stable across runs
rows = sorted(_MAPPING_JAX_ONNX.items())  # (jax_op, onnx_op) pairs
print(".. list-table::")
print("   :header-rows: 1")
print("   :widths: 40 60")
print()
print("   * - StableHLO op (``stablehlo.<name>``)")
print("     - ONNX op")
for jax_op, onnx_op in rows:
    print(f"   * - ``{jax_op}``")
    print(f"     - ``{onnx_op}``")
print()

>>>

StableHLO op (stablehlo.<name>)    ONNX op
-------------------------------    -------
abs                                Abs
add                                Add
and                                And
ceil                               Ceil
compare_EQ                         Equal
compare_GE                         GreaterOrEqual
compare_GT                         Greater
compare_LE                         LessOrEqual
compare_LT                         Less
cosine                             Cos
divide                             Div
exponential                        Exp
floor                              Floor
log                                Log
logistic                           Sigmoid
maximum                            Max
minimum                            Min
multiply                           Mul
negate                             Neg
not                                Not
or                                 Or
power                              Pow
remainder                          Mod
round_nearest_even                 Round
select                             Where
sign                               Sign
sine                               Sin
sqrt                               Sqrt
subtract                           Sub
tanh                               Tanh
xor                                Xor
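In code, a direct mapping amounts to a dictionary lookup keyed on the stripped op name. A minimal self-contained sketch, where DIRECT is a hypothetical stand-in for _MAPPING_JAX_ONNX truncated to three entries:

```python
# Hypothetical stand-in for _MAPPING_JAX_ONNX, truncated to three entries.
DIRECT = {"abs": "Abs", "add": "Add", "tanh": "Tanh"}

def onnx_op_for(stablehlo_op: str) -> str:
    """Strip the dialect prefix and look up the ONNX op name."""
    name = stablehlo_op.removeprefix("stablehlo.")
    try:
        return DIRECT[name]
    except KeyError:
        raise NotImplementedError(f"no direct ONNX mapping for {stablehlo_op}")

print(onnx_op_for("stablehlo.tanh"))  # Tanh
```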

Composite mappings

These stablehlo ops require more than one ONNX node and are implemented as small sub-graphs by dedicated factory functions:

<<<

from yobx.tensorflow.ops.jax_ops import _COMPOSITE_JAX_OPS

_descriptions = {
    "rsqrt": "``Reciprocal(Sqrt(x))``",
    "log_plus_one": "``Log(Add(x, 1))``",
    "exponential_minus_one": "``Sub(Exp(x), 1)``",
}

print(".. list-table::")
print("   :header-rows: 1")
print("   :widths: 40 60")
print()
print("   * - StableHLO op (``stablehlo.<name>``)")
print("     - ONNX equivalent")
for jax_op in sorted(_COMPOSITE_JAX_OPS):
    desc = _descriptions.get(jax_op, "*(see source)*")
    print(f"   * - ``{jax_op}``")
    print(f"     - {desc}")
print()

>>>

StableHLO op (stablehlo.<name>)    ONNX equivalent
-------------------------------    ---------------
compare_NE                         (see source)
exponential_minus_one              Sub(Exp(x), 1)
log_plus_one                       Log(Add(x, 1))
rsqrt                              Reciprocal(Sqrt(x))
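The composite decompositions can be sanity-checked numerically with pure-Python stand-ins for the ONNX sub-graphs:

```python
import math

def rsqrt(x):                  # Reciprocal(Sqrt(x))
    return 1.0 / math.sqrt(x)

def log_plus_one(x):           # Log(Add(x, 1))
    return math.log(x + 1.0)

def exponential_minus_one(x):  # Sub(Exp(x), 1)
    return math.exp(x) - 1.0

for x in (0.5, 1.0, 4.0):
    assert math.isclose(rsqrt(x), x ** -0.5)
    assert math.isclose(log_plus_one(x), math.log1p(x))
    assert math.isclose(exponential_minus_one(x), math.expm1(x))
```

Note that the ONNX forms Log(Add(x, 1)) and Sub(Exp(x), 1) do not retain the extra precision that log1p/expm1 offer for x near zero.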

Adding a new JAX op mapping

To add support for an additional stablehlo unary op:

  1. If the op maps 1-to-1 to an ONNX op, add an entry to _MAPPING_JAX_ONNX in yobx.tensorflow.ops.jax_ops:

    _MAPPING_JAX_ONNX["cbrt"] = "some_onnx_op"  # if a direct match exists
    
  2. If the op requires multiple ONNX nodes, add a _make_<name> factory function and register it in _COMPOSITE_JAX_OPS:

    def _make_cbrt(g):
        import numpy as np

        def _cbrt(*args, **kwargs):
            name = kwargs.pop("name", "cbrt")
            outputs = kwargs.pop("outputs", None)
            (x,) = args
            # Pow with a fractional exponent yields NaN for negative
            # inputs, so this matches stablehlo.cbrt only for x >= 0.
            exp = np.array(1.0 / 3.0, dtype=np.float32)
            kw = {"name": name}
            if outputs is not None:
                kw["outputs"] = outputs
            return g.op.Pow(x, exp, **kw)
        return _cbrt
    
    _COMPOSITE_JAX_OPS["cbrt"] = _make_cbrt
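Once registered, the converter calls the factory with its graph builder g and invokes the returned closure. The wiring can be exercised with a minimal stand-in for g (hypothetical: the real builder emits ONNX nodes rather than evaluating them):

```python
from types import SimpleNamespace

# Minimal stand-in for the converter's graph builder: g.op.Pow here just
# evaluates x ** exp instead of emitting an ONNX Pow node.
g = SimpleNamespace(op=SimpleNamespace(
    Pow=lambda x, exp, name=None, outputs=None: x ** exp,
))

def _make_cbrt(g):
    # Same shape as the factory from step 2, with a plain-float exponent.
    def _cbrt(*args, name="cbrt", outputs=None):
        (x,) = args
        kw = {"name": name}
        if outputs is not None:
            kw["outputs"] = outputs
        return g.op.Pow(x, 1.0 / 3.0, **kw)
    return _cbrt

cbrt = _make_cbrt(g)
print(cbrt(8.0))  # ~2.0
```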