torch.onnx.export and a custom operator registered with a function

This example shows how to convert a custom operator to ONNX. It is inspired by Python Custom Operators.

A model with a custom op

import numpy as np
from onnx.printer import to_text
import onnxscript
import torch

We define a model with a custom operator.

def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    assert x.device.type == "cpu"
    x_np = x.numpy()
    return torch.from_numpy(np.sin(x_np))


class ModuleWithACustomOperator(torch.nn.Module):
    def forward(self, x):
        return numpy_sin(x)


model = ModuleWithACustomOperator()

Let’s check it runs.

x = torch.randn(1, 3)
model(x)
tensor([[-0.1049, -0.2033,  0.8686]])
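A quick sanity check, sketched here under the assumption that the operator is meant to match torch.sin on CPU tensors: comparing both results confirms the numpy round trip does not change values or shape. The name x_check is introduced only for this illustration.

```python
import numpy as np
import torch


def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    # Same body as the function defined above, repeated to keep the sketch self-contained.
    assert x.device.type == "cpu"
    return torch.from_numpy(np.sin(x.numpy()))


# x_check is a hypothetical input used only for this comparison.
x_check = torch.randn(1, 3)

# Raises an AssertionError if values, dtype or shape differ beyond default tolerances.
torch.testing.assert_close(numpy_sin(x_check), torch.sin(x_check))
```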

As expected, it does not export.

try:
    torch.export.export(model, (x,))
except Exception as e:
    print(e)
else:
    # Raised outside the try block so that it is not swallowed by the except clause.
    raise AssertionError("This export should fail unless pytorch now supports this model.")
.numpy() is not supported for tensor subclasses.

The ONNX exporter relies on torch.export.export, so it is expected to fail with the same error.

try:
    torch.onnx.export(model, (x,), dynamo=True)
except Exception as e:
    print(e)
[torch.onnx] Obtain model graph for `ModuleWithACustomOperator()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModuleWithACustomOperator()` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModuleWithACustomOperator()` with `torch.export.export(..., strict=True)`...
[torch.onnx] Obtain model graph for `ModuleWithACustomOperator()` with `torch.export.export(..., strict=True)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅

Registration

The exporter does not know how to convert the new operator into ONNX. This must be defined. The first piece is to tell the exporter that the shape of the output is the same as x. The input names of the shape function must be the same as those of the operator.

def register(fct, fct_shape, namespace, fname):
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape


register(numpy_sin, lambda x: torch.empty_like(x), "mylib", "numpy_sin")

We also need to rewrite the module to be able to use it.

class ModuleWithACustomOperator(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.numpy_sin(x)


model = ModuleWithACustomOperator()

Let’s check it runs again.

model(x)
tensor([[-0.1049, -0.2033,  0.8686]])

Let’s see what the FX graph looks like.

print(torch.export.export(model, (x,)).graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %numpy_sin : [num_users=1] = call_function[target=torch.ops.mylib.numpy_sin.default](args = (%x,), kwargs = {})
    return (numpy_sin,)

Next is the conversion to ONNX. The custom operator is mapped to the ONNX operator Sin.

op = onnxscript.opset18


@onnxscript.script()
def numpy_sin_to_onnx(x) -> onnxscript.onnx_types.TensorType:
    return op.Sin(x)

And we convert again.

ep = torch.onnx.export(
    model,
    (x,),
    custom_translation_table={torch.ops.mylib.numpy_sin.default: numpy_sin_to_onnx},
    dynamo=True,
)

print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModuleWithACustomOperator()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModuleWithACustomOperator()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
   ir_version: 10,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.8.0.dev20250519+cu126"
>
main_graph (float[1,3] x) => (float[1,3] numpy_sin) {
   [n0] numpy_sin = Sin (x)
}

Related examples

torch.onnx.export and a custom operator inplace

to_onnx and a custom operator registered with a function

to_onnx and a custom operator inplace

torch.onnx.export and a model with a test

to_onnx and a model with a test
