onnx_diagnostic.export.onnx_plug¶
- class onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx(eager_fn: Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], shape_fn: Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]], function_proto: FunctionProto, n_inputs: int | None = None, n_outputs: int | None = None, name: str | None = None, kwargs: Dict[str, int | float] | None = None, verbose: int = 0)[source]¶
Replaces a piece of code with another one written in ONNX at export time. The class inserts a custom operator and links it to eager_fn.
- Parameters:
eager_fn – the code being replaced; it must be given so that the torch.fx.Graph produced by the exporter can still be executed
shape_fn – a function producing dummy outputs with the shapes the exporter needs for the next operators in the graph
function_proto – an instance of onnx.FunctionProto; its domain must be onnx_plug
n_inputs – number of inputs of the function; if not given, the class infers it from the signature of eager_fn (only tensors are counted)
n_outputs – same for the number of outputs (only tensors are counted)
name – the name of the custom op; the function name if not specified
kwargs – constant parameters with their default values
verbose – verbosity level
Here is an example:
<<<
import onnx.helper as oh
import torch
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


def demo_customsub(x, y):
    return x - y


def demo_customsub_shape(x, y):
    return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)


def make_function_proto():
    return oh.make_function(
        "onnx_plug",
        "demo_customsub",
        ["x", "y"],
        ["z"],
        [oh.make_node("Sub", ["x", "y"], ["z"])],
        opset_imports=[oh.make_opsetid("", 22)],
    )


class Model(torch.nn.Module):
    def forward(self, x):
        y = x.sum(axis=1, keepdim=True)
        d = torch.ops.onnx_plug.demo_customsub(x, y)
        return torch.abs(d)


replacements = [
    EagerDirectReplacementWithOnnx(
        demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
    )
]

x = torch.randn((3, 4), dtype=torch.float32)
model = Model()
ds = ({0: "d1", 1: "d2"},)

# The exported program shows a custom op.
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
print("ep")

# As the exporter knows how to replace this custom op,
# let's export.
onx = to_onnx(
    model,
    (x,),
    dynamic_shapes=ds,
    exporter="custom",
    onnx_plugs=replacements,
    target_opset=22,
    inline=False,
).model_proto
print(pretty_onnx(onx))

# And with :func:`torch.onnx.export`:
onx = to_onnx(
    model,
    (x,),
    dynamic_shapes=ds,
    exporter="onnx-dynamo",
    onnx_plugs=replacements,
    target_opset=22,
    inline=False,
).model_proto
print(pretty_onnx(onx))
>>>
ep
opset: domain='' version=22
opset: domain='onnx_plug' version=1
input: name='x' type=dtype('float32') shape=['d1', 'd2']
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape
ReduceSum(x, init7_s1_1, keepdims=1) -> sum_1
demo_customsub[onnx_plug](x, sum_1) -> demo_customsub
Abs(demo_customsub) -> output_0
output: name='output_0' type=dtype('float32') shape=['d1', 'd2']
----- function name=demo_customsub domain=onnx_plug
opset: domain='' version=22
input: 'x'
input: 'y'
Sub(x, y) -> z
output: name='z' type=? shape=?
~/github/onnx-diagnostic/onnx_diagnostic/export/api.py:161: UserWarning: Exporting a model while it is in training mode. Please ensure that this is intended, as it may lead to different behavior during inference. Calling model.eval() before export is recommended.
  epo = torch.onnx.export(
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Model()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
The value type for shape [d1,d2] is not known. Please set type for the value. Skipping serialization
Shape inference failed: %s. Model is left unchanged
Traceback (most recent call last):
  File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 83, in call
    inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/ir-py/src/onnx_ir/passes/common/_c_api_utils.py", line 64, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 75, in partial_infer_shapes
    return onnx.shape_inference.infer_shapes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx/onnx/shape_inference.py", line 58, in infer_shapes
    inferred_model_str = C.infer_shapes(
                         ^^^^^^^^^^^^^^^
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Abs, node name: node_abs_1): [TypeInferenceError] Input 0 expected to have type but instead is null
The value type for shape [d1,d2] is not known. Please set type for the value. Skipping serialization
Shape inference failed: %s. Model is left unchanged
Traceback (most recent call last):
  File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 83, in call
    inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/ir-py/src/onnx_ir/passes/common/_c_api_utils.py", line 64, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 75, in partial_infer_shapes
    return onnx.shape_inference.infer_shapes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx/onnx/shape_inference.py", line 58, in infer_shapes
    inferred_model_str = C.infer_shapes(
                         ^^^^^^^^^^^^^^^
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Abs, node name: node_abs_1): [TypeInferenceError] Input 0 expected to have type but instead is null
The value type for shape [d1,d2] is not known. Please set type for the value. Skipping serialization
Shape inference failed: %s. Model is left unchanged
Traceback (most recent call last):
  File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 83, in call
    inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/ir-py/src/onnx_ir/passes/common/_c_api_utils.py", line 64, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "~/github/ir-py/src/onnx_ir/passes/common/shape_inference.py", line 75, in partial_infer_shapes
    return onnx.shape_inference.infer_shapes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx/onnx/shape_inference.py", line 58, in infer_shapes
    inferred_model_str = C.infer_shapes(
                         ^^^^^^^^^^^^^^^
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Abs, node name: node_abs_1): [TypeInferenceError] Input 0 expected to have type but instead is null
The value type for shape [d1,d2] is not known. Please set type for the value. Skipping serialization
opset: domain='' version=22
opset: domain='onnx_plug' version=1
input: name='x' type=dtype('float32') shape=['d1', 'd2']
init: name='val_3' type=int64 shape=(1,) -- array([1])
ReduceSum(x, val_3, noop_with_empty_axes=0, keepdims=1) -> sum_1
demo_customsub[onnx_plug](x, sum_1) -> demo_customsub
Abs(demo_customsub) -> abs_1
output: name='abs_1' type=dtype('float32') shape=['d1', 'd2']
----- function name=demo_customsub domain=onnx_plug
opset: domain='' version=22
input: 'x'
input: 'y'
Sub(x, y) -> z
output: name='z' type=? shape=?
- custom_converter() → Callable[source]¶
Returns a function which converts the custom ops found in the fx graph into ONNX, following the API of the custom exporter. The converter adds a custom op and registers the local function. Both converter accessors are sketched below.
- onnx_dynamo_converter() → Callable[source]¶
Returns a function which converts the custom ops found in the fx graph into ONNX, following the API of torch.onnx.export().
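As a minimal sketch, reusing replacements[0] from the example above, both methods can be called directly to retrieve the converter callables each exporter registers for the custom op; the callables themselves follow each exporter's conversion API and are only retrieved here:
<<<
rep = replacements[0]

# Converter used when exporting with exporter="custom".
custom_conv = rep.custom_converter()
# Converter used when exporting with exporter="onnx-dynamo".
dynamo_conv = rep.onnx_dynamo_converter()

# Both are plain callables registered with the corresponding exporter.
print(custom_conv, dynamo_conv)
>>>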
- verify(*args, engine: Callable | None = None, dump_onnx_model: str | None = None, **kwargs) → VerifyResult[source]¶
Verifies that the eager mode is equivalent to the onnx function given as a replacement. This method evaluates eager_fn, checks that the shapes match the ones given by shape_fn, and finally evaluates the onnx translation if the previous steps did not fail.
- Parameters:
args – function inputs
kwargs – additional arguments passed to eager_fn
engine – by default an instance of onnx_diagnostic.reference.OnnxruntimeEvaluator
dump_onnx_model – if specified, dumps the onnx model used to verify that eager and onnx produce the same results
- Returns:
outputs of onnx_diagnostic.helpers.max_diff()
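A minimal sketch of verify, reusing the replacements list from the example above; the tensor values are illustrative:
<<<
import torch

rep = replacements[0]
x = torch.randn((3, 4), dtype=torch.float32)
y = x.sum(axis=1, keepdim=True)

# Evaluates eager_fn, checks its outputs against shape_fn, then runs the
# ONNX function and compares both results.
res = rep.verify(x, y)
print(res.eager_outputs)
print(res.onnx_outputs)
print(res.diffs)
>>>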
- class onnx_diagnostic.export.onnx_plug.VerifyResult(eager_outputs: Tuple[Tensor, ...], onnx_outputs: Tuple[Tensor, ...], diffs: Tuple[Dict[str, float], ...])[source]¶
Outputs of method verify.
- onnx_diagnostic.export.onnx_plug.is_exporting() → bool[source]¶
Returns torch.compiler.is_exporting() or torch.compiler.is_compiling(). Change _TEST_EXPORT to make it trigger.
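A minimal sketch, assuming is_exporting() is used to select a different spelling of the same computation while the model is being exported or compiled; the model below is illustrative:
<<<
import torch
from onnx_diagnostic.export.onnx_plug import is_exporting


class WithExportBranch(torch.nn.Module):
    def forward(self, x):
        if is_exporting():
            # branch traced by torch.export.export or torch.compile
            return torch.abs(x - x.mean())
        # eager branch, free to use code the exporter cannot handle
        return (x - x.mean()).abs()


ep = torch.export.export(WithExportBranch(), (torch.randn((3, 4)),))
print(ep)
>>>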