onnx_diagnostic.export.onnx_plug

class onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx(eager_fn: Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], shape_fn: Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], function_proto: FunctionProto | Dict[Any, FunctionProto], n_inputs: int | None = None, n_outputs: int | None = None, name: str | None = None, kwargs: Dict[str, int | float] | None = None, verbose: int = 0, version_selector: Callable[[...], Tuple[Any, ...]] | None = None, default_opset: int = 22)

Replaces a piece of code with an equivalent one written in ONNX at export time. It inserts a custom operator and links it to eager_fn.

Parameters:
  • eager_fn – the code it replaces; it is required so that the torch.fx.Graph produced by the exporter can still be executed

  • shape_fn – a function producing dummy outputs with the shapes the exporter uses for the next operators in the graph

  • function_proto – an onnx.FunctionProto, or a dictionary of them when version_selector is used; its domain must be onnx_plug

  • n_inputs – number of inputs of the function; if not given, the class infers it from the signature of eager_fn; only tensors must be counted

  • n_outputs – same for the number of outputs; only tensors must be counted

  • name – the name of the custom op; defaults to the function name

  • kwargs – constant parameters with their default values

  • version_selector – selects the version based on the arguments (see the example below); this lets the user define different ONNX versions depending on the inputs

  • default_opset – opset to use by default

  • verbose – verbose level
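The exporter relies on the dummy outputs of shape_fn to type the rest of the graph, so shape_fn should agree with eager_fn on the shapes and dtypes of the outputs. A minimal sketch of that check, reusing demo_customsub and demo_customsub_shape from the example below; check_shape_fn is a hypothetical helper, not part of onnx_diagnostic (verify performs a similar check).

```python
import torch


def demo_customsub(x, y):
    # eager_fn: the code being replaced
    return x - y


def demo_customsub_shape(x, y):
    # shape_fn: only shape and dtype matter, values are left uninitialized
    return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


def check_shape_fn(eager_fn, shape_fn, *args):
    """Hypothetical helper: True if shape_fn matches eager_fn on shapes and dtypes."""
    expected = eager_fn(*args)
    dummy = shape_fn(*args)
    # Both functions may return a single tensor or a tuple of tensors.
    if isinstance(expected, torch.Tensor):
        expected = (expected,)
    if isinstance(dummy, torch.Tensor):
        dummy = (dummy,)
    return len(expected) == len(dummy) and all(
        e.shape == d.shape and e.dtype == d.dtype for e, d in zip(expected, dummy)
    )


x = torch.randn((3, 4))
y = x.sum(dim=1, keepdim=True)
print(check_shape_fn(demo_customsub, demo_customsub_shape, x, y))  # True
```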

Here is an example:

<<<

import onnx.helper as oh
import torch
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


def demo_customsub(x, y):
    return x - y


def demo_customsub_shape(x, y):
    return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


def make_function_proto():
    return oh.make_function(
        "onnx_plug",
        "demo_customsub",
        ["x", "y"],
        ["z"],
        [oh.make_node("Sub", ["x", "y"], ["z"])],
        opset_imports=[oh.make_opsetid("", 22)],
    )


class Model(torch.nn.Module):
    def forward(self, x):
        y = x.sum(axis=1, keepdim=True)
        d = torch.ops.onnx_plug.demo_customsub(x, y)
        return torch.abs(d)


replacements = [
    EagerDirectReplacementWithOnnx(
        demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
    )
]

x = torch.randn((3, 4), dtype=torch.float32)
model = Model()
ds = ({0: "d1", 1: "d2"},)

# The exported program shows a custom op.
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
print("ep")

# The exporter knows how to replace this custom op.
# Let's export.

onx = to_onnx(
    model,
    (x,),
    dynamic_shapes=ds,
    exporter="custom",
    onnx_plugs=replacements,
    target_opset=22,
    inline=False,
).model_proto

print(pretty_onnx(onx))

>>>

    ep
    opset: domain='' version=22
    opset: domain='onnx_plug' version=1
    input: name='x' type=dtype('float32') shape=['d1', 'd2']
    init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape
    ReduceSum(x, init7_s1_1, keepdims=1) -> sum_1
      demo_customsub[onnx_plug](x, sum_1) -> demo_customsub
        Abs(demo_customsub) -> output_0
    output: name='output_0' type=dtype('float32') shape=['d1', 'd2']
    ----- function name=demo_customsub domain=onnx_plug
    opset: domain='' version=22
    input: 'x'
    input: 'y'
    Sub(x, y) -> z
    output: name='z' type=? shape=?

We do the same with torch.onnx.export() by setting exporter="onnx-dynamo":

<<<

import onnx.helper as oh
import torch
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


def demo_customsub(x, y):
    return x - y


def demo_customsub_shape(x, y):
    return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


def make_function_proto():
    return oh.make_function(
        "onnx_plug",
        "demo_customsub",
        ["x", "y"],
        ["z"],
        [oh.make_node("Sub", ["x", "y"], ["z"])],
        opset_imports=[oh.make_opsetid("", 22)],
    )


class Model(torch.nn.Module):
    def forward(self, x):
        y = x.sum(axis=1, keepdim=True)
        d = torch.ops.onnx_plug.demo_customsub(x, y)
        return torch.abs(d)


replacements = [
    EagerDirectReplacementWithOnnx(
        demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
    )
]

x = torch.randn((3, 4), dtype=torch.float32)
model = Model()
ds = ({0: "d1", 1: "d2"},)

# The exported program shows a custom op.
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
print("ep")

# The exporter knows how to replace this custom op.
# Let's export.

onx = to_onnx(
    model,
    (x,),
    dynamic_shapes=ds,
    exporter="onnx-dynamo",
    onnx_plugs=replacements,
    target_opset=22,
    inline=False,
).model_proto

print(pretty_onnx(onx))

>>>

    ep
    ~/github/onnx-diagnostic/onnx_diagnostic/export/api.py:175: UserWarning: Exporting a model while it is in training mode. Please ensure that this is intended, as it may lead to different behavior during inference. Calling model.eval() before export is recommended.
      epo = torch.onnx.export(
    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    /usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
      return cls.__new__(cls, *args)
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    The value type for shape [d1,d2] is not known. Please set type for the value. Skipping serialization
    Shape inference failed: %s. Model is left unchanged
    Traceback (most recent call last):
      File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 83, in call
        inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/ir-py/src/onnx_ir/passes/common/_c_api_utils.py", line 64, in call_onnx_api
        result = func(proto)
                 ^^^^^^^^^^^
      File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 75, in partial_infer_shapes
        return onnx.shape_inference.infer_shapes(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/onnx/onnx/shape_inference.py", line 58, in infer_shapes
        inferred_model_str = C.infer_shapes(
                             ^^^^^^^^^^^^^^^
    onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Abs, node name: node_abs_1): [TypeInferenceError] Input 0 expected to have type but instead is null
    
    The value type for shape [d1,d2] is not known. Please set type for the value. Skipping serialization
    opset: domain='' version=22
    opset: domain='onnx_plug' version=1
    input: name='x' type=dtype('float32') shape=['d1', 'd2']
    init: name='val_3' type=int64 shape=(1,) -- array([1])
    ReduceSum(x, val_3, noop_with_empty_axes=0, keepdims=1) -> sum_1
      demo_customsub[onnx_plug](x, sum_1) -> demo_customsub
        Abs(demo_customsub) -> abs_1
    output: name='abs_1' type=dtype('float32') shape=['d1', 'd2']
    ----- function name=demo_customsub domain=onnx_plug
    opset: domain='' version=22
    input: 'x'
    input: 'y'
    Sub(x, y) -> z
    output: name='z' type=? shape=?

The following example shows how to define multiple versions depending on the device, the dtype, or the targeted ONNX opset.

from typing import Tuple

import onnx
import torch
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx

# qwen_sdpa_attention, PackedAttention, LoopAttention23, LoopMHAAttention,
# _add_com_microsoft_opset, _update_sequence_type and torch_dtype_to_onnx_dtype
# are defined elsewhere and not shown here.


def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, int]:
    # The first tensor which is not None determines the dtype and the device.
    first_tensor = next(a for a in args if a is not None)
    dtype = first_tensor.dtype
    itype = torch_dtype_to_onnx_dtype(dtype)
    if dtype == torch.float32:
        if opset >= 23:
            return "LOOPA23", itype
        return "LOOPMHA", itype
    if dtype == torch.float16:
        # float16 on CUDA selects the PACKED version, otherwise LOOPMHA.
        if first_tensor.is_cuda:
            return "PACKED", itype
        return "LOOPMHA", itype
    raise AssertionError(
        f"Unable to handle type {dtype} (itype={itype}) "
        f"on device {first_tensor.device} with opset={opset}"
    )


qwen_sdpa_attention_versatile = EagerDirectReplacementWithOnnx(
    qwen_sdpa_attention,
    # shape_fn: swaps dimensions 1 and 2 of the first input
    lambda qs, *args, **kwargs: torch.empty(
        (qs.shape[0], qs.shape[2], qs.shape[1], qs.shape[3]),
        dtype=qs.dtype,
        device=qs.device,
    ),
    # The keys match the tuples returned by qwen_version_selector.
    {
        ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
            PackedAttention.to_function_proto()
        ),
        ("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
        ("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
            onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
        ),
        ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
            LoopMHAAttention.to_function_proto()
        ),
        ("LOOPMHA", onnx.TensorProto.FLOAT16): _update_sequence_type(
            onnx.TensorProto.FLOAT16,
            _add_com_microsoft_opset(LoopMHAAttention.to_function_proto()),
        ),
    },
    n_inputs=4,
    n_outputs=1,
    kwargs=dict(scaling=0.11180339887498948, num_heads=16),
    name="qwen_sdpa_attention_versatile",
    version_selector=qwen_version_selector,
)

custom_converter() → Callable

Returns a function which converts a custom op found in the fx graph into ONNX following the API of the custom exporter. The converter adds a custom op and registers the local function.

property domain: str

Returns the onnx domain.

get_function_proto(opset: int, *args) → FunctionProto

Returns the version of the function to use, based on the opset and the inputs.

onnx_dynamo_converter() → Callable

Returns a function which converts a custom op found in the fx graph into ONNX following the API of torch.onnx.export().

property target_name: str

Returns the target name (as it appears in the exported program).

property torch_op: Callable

Returns torch.ops.onnx_plug.<name>.

verify(*args, engine: Callable | None = None, dump_onnx_model: str | None = None, opset: int = 22, **kwargs) → VerifyResult

Verifies that eager_fn and the ONNX function given as a replacement produce equivalent results. This method evaluates eager_fn, checks that the output shapes match the ones given by shape_fn, and finally evaluates the ONNX translation if the previous steps did not fail.

Parameters:
  • args – function inputs

  • kwargs – additional keyword arguments for eager_fn

  • engine – the runtime evaluating the ONNX function, by default an instance of onnx_diagnostic.reference.OnnxruntimeEvaluator

  • dump_onnx_model – if specified, where to dump the ONNX model used to check that eager and ONNX produce the same results

  • opset – ONNX opset to use

Returns:

a VerifyResult holding the eager outputs, the ONNX outputs and the discrepancies computed with onnx_diagnostic.helpers.max_diff()

class onnx_diagnostic.export.onnx_plug.VerifyResult(eager_outputs: Tuple[Tensor, ...], onnx_outputs: Tuple[Tensor, ...], diffs: Tuple[Dict[str, float], ...])

Outputs of method verify: the eager outputs, the ONNX outputs and the discrepancies between them.

onnx_diagnostic.export.onnx_plug.is_exporting() → bool

Returns torch.compiler.is_exporting() or torch.compiler.is_compiling(). Change _TEST_EXPORT to make it trigger.