onnx_diagnostic.export.onnx_plug

class onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx(eager_fn: Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], shape_fn: Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], function_proto: FunctionProto, n_inputs: int | None = None, n_outputs: int | None = None, name: str | None = None, kwargs: Dict[str, int | float] | None = None, verbose: int = 0)

Replaces a piece of code with another one written in ONNX at export time. It inserts a custom operator and links it to eager_fn.

Parameters:
  • eager_fn – the code being replaced; it is required so that the torch.fx.Graph produced by the exporter can still be executed

  • shape_fn – a function producing dummy outputs with the shapes the exporter uses to process the next operators in the graph

  • function_proto – an instance of onnx.FunctionProto; its domain must be onnx_plug

  • n_inputs – number of inputs of the function; if not given, the class infers it from the signature of eager_fn. Only tensors must be counted.

  • n_outputs – same for the number of outputs; only tensors must be counted

  • name – the name of the custom op; defaults to the function name

  • kwargs – constant parameters with their default values

  • verbose – verbose level
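When n_inputs is omitted, the class infers it from the signature of eager_fn. The exact rule is not documented here; the sketch below assumes that tensor inputs are the positional parameters without a default value, while parameters with a default correspond to the constants listed in kwargs. count_tensor_inputs and demo_scaled_sub are hypothetical names, not part of onnx_diagnostic.

```python
import inspect
from typing import Callable


def count_tensor_inputs(fn: Callable) -> int:
    """Hypothetical helper: counts positional parameters without a default.

    Assumption: those are the tensor inputs, parameters with a default
    being the constants passed through ``kwargs``.
    """
    sig = inspect.signature(fn)
    return sum(
        1
        for p in sig.parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
    )


def demo_customsub(x, y):
    return x - y


def demo_scaled_sub(x, y, alpha: float = 1.0):
    # alpha would be a constant parameter, not a tensor input
    return x - alpha * y


print(count_tensor_inputs(demo_customsub))   # 2
print(count_tensor_inputs(demo_scaled_sub))  # 2
```

Under this assumption, passing n_inputs explicitly, as the example below does, only matters when the signature does not follow that convention.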

Here is an example:

<<<

import onnx.helper as oh
import torch
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


def demo_customsub(x, y):
    return x - y


def demo_customsub_shape(x, y):
    return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


def make_function_proto():
    return oh.make_function(
        "onnx_plug",
        "demo_customsub",
        ["x", "y"],
        ["z"],
        [oh.make_node("Sub", ["x", "y"], ["z"])],
        opset_imports=[oh.make_opsetid("", 22)],
    )


class Model(torch.nn.Module):
    def forward(self, x):
        y = x.sum(axis=1, keepdim=True)
        d = torch.ops.onnx_plug.demo_customsub(x, y)
        return torch.abs(d)


replacements = [
    EagerDirectReplacementWithOnnx(
        demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
    )
]

x = torch.randn((3, 4), dtype=torch.float32)
model = Model()
ds = ({0: "d1", 1: "d2"},)

# The exported program shows a custom op.
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
print(ep)

# The exporter knows how to replace this custom op.
# Let's export.

onx = to_onnx(
    model,
    (x,),
    dynamic_shapes=ds,
    exporter="custom",
    onnx_plugs=replacements,
    target_opset=22,
    inline=False,
).model_proto

print(pretty_onnx(onx))

# And with torch.onnx.export (exporter="onnx-dynamo"):

onx = to_onnx(
    model,
    (x,),
    dynamic_shapes=ds,
    exporter="onnx-dynamo",
    onnx_plugs=replacements,
    target_opset=22,
    inline=False,
).model_proto

print(pretty_onnx(onx))

>>>

    ep
    opset: domain='' version=22
    opset: domain='onnx_plug' version=1
    input: name='x' type=dtype('float32') shape=['d1', 'd2']
    init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape
    ReduceSum(x, init7_s1_1, keepdims=1) -> sum_1
      demo_customsub[onnx_plug](x, sum_1) -> demo_customsub
        Abs(demo_customsub) -> output_0
    output: name='output_0' type=dtype('float32') shape=['d1', 'd2']
    ----- function name=demo_customsub domain=onnx_plug
    opset: domain='' version=22
    input: 'x'
    input: 'y'
    Sub(x, y) -> z
    output: name='z' type=? shape=?
    ~/github/onnx-diagnostic/onnx_diagnostic/export/api.py:161: UserWarning: Exporting a model while it is in training mode. Please ensure that this is intended, as it may lead to different behavior during inference. Calling model.eval() before export is recommended.
      epo = torch.onnx.export(
    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    /usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
      return cls.__new__(cls, *args)
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    The value type for shape [d1,d2] is not known. Please set type for the value. Skipping serialization
    Shape inference failed: %s. Model is left unchanged
    Traceback (most recent call last):
      File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 83, in call
        inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/ir-py/src/onnx_ir/passes/common/_c_api_utils.py", line 64, in call_onnx_api
        result = func(proto)
                 ^^^^^^^^^^^
      File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 75, in partial_infer_shapes
        return onnx.shape_inference.infer_shapes(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/onnx/onnx/shape_inference.py", line 58, in infer_shapes
        inferred_model_str = C.infer_shapes(
                             ^^^^^^^^^^^^^^^
    onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Abs, node name: node_abs_1): [TypeInferenceError] Input 0 expected to have type but instead is null
    
    opset: domain='' version=22
    opset: domain='onnx_plug' version=1
    input: name='x' type=dtype('float32') shape=['d1', 'd2']
    init: name='val_3' type=int64 shape=(1,) -- array([1])
    ReduceSum(x, val_3, noop_with_empty_axes=0, keepdims=1) -> sum_1
      demo_customsub[onnx_plug](x, sum_1) -> demo_customsub
        Abs(demo_customsub) -> abs_1
    output: name='abs_1' type=dtype('float32') shape=['d1', 'd2']
    ----- function name=demo_customsub domain=onnx_plug
    opset: domain='' version=22
    input: 'x'
    input: 'y'
    Sub(x, y) -> z
    output: name='z' type=? shape=?
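The shape function demo_customsub_shape relies on torch.broadcast_shapes to announce the output shape ['d1', 'd2'] reported by both exporters. The broadcasting rule it applies can be sketched in pure Python. This is an illustration only, not torch's implementation: it is a simplification that treats a symbolic dimension such as "d1" as compatible only with itself or with 1.

```python
from typing import Sequence, Tuple, Union

Dim = Union[int, str]


def broadcast_shapes(*shapes: Sequence[Dim]) -> Tuple[Dim, ...]:
    """Numpy-style broadcasting over static or symbolic dimensions.

    Shapes are aligned on their trailing dimensions; a dimension equal
    to 1 stretches to match the other one.
    """
    rank = max(len(s) for s in shapes)
    result = []
    for i in range(-rank, 0):
        dim: Dim = 1
        for s in shapes:
            if len(s) < -i:
                continue  # this shape has no dimension at this position
            d = s[i]
            if d == 1:
                continue  # 1 broadcasts to anything
            if dim == 1:
                dim = d
            elif d != dim:
                raise ValueError(f"incompatible dimensions {dim!r} and {d!r}")
        result.append(dim)
    return tuple(result)


# x has shape (d1, d2), y = x.sum(axis=1, keepdim=True) has shape (d1, 1)
print(broadcast_shapes(("d1", "d2"), ("d1", 1)))  # ('d1', 'd2')
```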
custom_converter() → Callable

Returns a function which converts a custom op found in the fx graph into ONNX following the API of the custom exporter. The converter adds a custom op and registers the local function.
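The converter has two responsibilities, visible in the first model printed above: it emits a node demo_customsub[onnx_plug](x, sum_1) and it adds the local function demo_customsub to the model. The sketch below shows only that pattern. The signature the custom exporter actually expects is not given here, so ToyGraphBuilder, make_converter and their arguments are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Node:
    op_type: str
    domain: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class ToyGraphBuilder:
    """Hypothetical stand-in for the graph builder given to a converter."""

    nodes: List[Node] = field(default_factory=list)
    functions: Dict[Tuple[str, str], object] = field(default_factory=dict)

    def add_function(self, domain: str, name: str, proto: object) -> None:
        # A local function is registered once, however many nodes call it.
        self.functions.setdefault((domain, name), proto)


def make_converter(domain: str, name: str, proto: object):
    """Returns a converter emitting a call to the local function."""

    def converter(g: ToyGraphBuilder, outputs: List[str], *inputs: str) -> List[str]:
        g.add_function(domain, name, proto)
        # The node type is the function name, its domain the function domain.
        g.nodes.append(Node(name, domain, list(inputs), list(outputs)))
        return outputs

    return converter


conv = make_converter("onnx_plug", "demo_customsub", proto="<FunctionProto>")
g = ToyGraphBuilder()
conv(g, ["demo_customsub"], "x", "sum_1")
print(g.nodes[0])
print(list(g.functions))  # [('onnx_plug', 'demo_customsub')]
```

Keeping the function registry keyed by (domain, name) means a model calling the custom op several times still contains a single function definition.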

property domain: str

Returns the onnx domain.

onnx_dynamo_converter() → Callable

Returns a function which converts a custom op found in the fx graph into ONNX following the API of torch.onnx.export().

property target_name: str

Returns the target name, as it appears in the exported program.

property torch_op: Callable

Returns torch.ops.onnx_plug.<name>.

verify(*args, engine: Callable | None = None, dump_onnx_model: str | None = None, **kwargs) → VerifyResult

Verifies that the eager mode is equivalent to the ONNX function given as a replacement. This method evaluates eager_fn, checks that the output shapes match the ones given by shape_fn, and finally evaluates the ONNX translation if the previous steps did not fail.

Parameters:
  • args – function inputs

  • kwargs – additional keyword arguments for eager_fn

  • engine – the runtime evaluating the ONNX function, by default an instance of onnx_diagnostic.reference.OnnxruntimeEvaluator

  • dump_onnx_model – if given, where to dump the ONNX model used to check that eager and ONNX produce the same results

Returns:

a VerifyResult whose diffs are outputs of onnx_diagnostic.helpers.max_diff()
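The three steps of verify can be sketched with nested Python lists standing in for tensors. Everything below is hypothetical: toy_verify and ToyVerifyResult only mirror the documented order of operations and the fields of VerifyResult, the "abs" key is an assumption, and the function sub is reused as a stand-in for the ONNX evaluation that the real method delegates to engine before comparing outputs with onnx_diagnostic.helpers.max_diff().

```python
from typing import Dict, List, NamedTuple, Tuple


class ToyVerifyResult(NamedTuple):
    # Same three fields as VerifyResult.
    eager_outputs: Tuple
    onnx_outputs: Tuple
    diffs: Tuple[Dict[str, float], ...]


def shape_of(t) -> Tuple[int, ...]:
    """Shape of a nested list, e.g. [[1.0, 2.0]] -> (1, 2)."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return tuple(shape)


def flatten(t) -> List[float]:
    if not isinstance(t, list):
        return [t]
    return [v for item in t for v in flatten(item)]


def toy_verify(eager_fn, shape_fn, onnx_fn, *args) -> ToyVerifyResult:
    # 1. Run the eager implementation.
    eager = eager_fn(*args)
    # 2. The dummy outputs of shape_fn must have the same shapes.
    for e, s in zip(eager, shape_fn(*args)):
        if shape_of(e) != shape_of(s):
            raise AssertionError(f"shape mismatch: {shape_of(e)} != {shape_of(s)}")
    # 3. Only then evaluate the ONNX side and compare values.
    onnx = onnx_fn(*args)
    diffs = tuple(
        {"abs": max((abs(a - b) for a, b in zip(flatten(e), flatten(o))), default=0.0)}
        for e, o in zip(eager, onnx)
    )
    return ToyVerifyResult(tuple(eager), tuple(onnx), diffs)


def sub(x, y):
    # eager_fn: element-wise x - y on 2D nested lists, returns a tuple of outputs
    return ([[a - b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)],)


def sub_shape(x, y):
    # shape_fn: dummy zeros with the expected output shape
    return ([[0.0] * len(x[0]) for _ in x],)


res = toy_verify(sub, sub_shape, sub, [[1.0, 2.0]], [[0.5, 0.5]])
print(res.diffs)  # ({'abs': 0.0},)
```

Checking shapes before running the ONNX side mirrors the documented behaviour: a wrong shape_fn is reported before any runtime is involved.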

class onnx_diagnostic.export.onnx_plug.VerifyResult(eager_outputs: Tuple[Tensor, ...], onnx_outputs: Tuple[Tensor, ...], diffs: Tuple[Dict[str, float], ...])

Outputs of method verify.

onnx_diagnostic.export.onnx_plug.is_exporting() → bool

Returns torch.compiler.is_exporting() or torch.compiler.is_compiling(). Changing _TEST_EXPORT forces it to trigger.