to_onnx and a custom operator registered with a function
This example shows how to convert a custom operator, inspired by Python Custom Operators.
A model with a custom op
from typing import Any, Dict, List
import numpy as np
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx, Dispatcher
We define a model with a custom operator.
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    assert x.device.type == "cpu"
    x_np = x.numpy()
    return torch.from_numpy(np.sin(x_np))


class ModuleWithACustomOperator(torch.nn.Module):
    def forward(self, x):
        return numpy_sin(x)


model = ModuleWithACustomOperator()
Let’s check it runs.
x = torch.randn(1, 3)
model(x)
tensor([[-0.9442, 0.5768, -0.9167]])
As expected, it does not export.
try:
    torch.export.export(model, (x,))
    raise AssertionError("This export should failed unless pytorch now supports this model.")
except Exception as e:
    print(e)
This export should failed unless pytorch now supports this model.
The exporter fails with the same error because it relies on torch.export.export working.
Registration
The exporter needs to be told how to convert the new operator into ONNX, so a converter must be defined. The first piece is to tell the exporter that the shape of the output is the same as the shape of x. The shape function must use the same input names as the operator.
def register(fct, fct_shape, namespace, fname):
    # Infer the operator schema from the type hints of fct.
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    # Use fct itself as the CPU kernel.
    custom_def.register_kernel("cpu")(fct)
    # fct_shape produces an output of the right shape without running the kernel.
    custom_def._abstract_fn = fct_shape


register(numpy_sin, lambda x: torch.empty_like(x), "mylib", "numpy_sin")


class ModuleWithACustomOperator(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.numpy_sin(x)


model = ModuleWithACustomOperator()
Let’s check it runs again.
model(x)
tensor([[-0.9442, 0.5768, -0.9167]])
Let’s see what the fx graph looks like.
print(torch.export.export(model, (x,)).graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %numpy_sin : [num_users=1] = call_function[target=torch.ops.mylib.numpy_sin.default](args = (%x,), kwargs = {})
    return (numpy_sin,)
Next is the conversion to onnx.
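The converter numpy_sin_to_onnx passed to the dispatcher is not defined in the text above. The following is a minimal sketch of what such a converter could look like, not necessarily the original one. It assumes the dispatcher calls the converter with the graph builder, a dictionary of shape and type information, the list of output names, and then the operator's input names, and that the builder exposes the ONNX operator as g.op.Sin. The node name and the Sin operator are taken from the printed graph further down.

```python
from typing import Any, Dict, List


def numpy_sin_to_onnx(
    g: Any,  # assumed to be a GraphBuilder instance
    sts: Dict[str, Any],  # assumed: shape/type information, unused here
    outputs: List[str],  # names the exporter expects for the results
    x: str,  # name of the input result in the ONNX graph
    name: str = "mylib.numpy_sin",
) -> str:
    # Assumed builder API: g.op.Sin adds a Sin node and returns its output name.
    # name=... makes it easy to find where the node came from in the final graph.
    return g.op.Sin(x, name=name, outputs=outputs)
```

Passing the names in outputs through to the node keeps the result names aligned with what the exporter expects for this FX node.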
We create a Dispatcher that maps the name of the custom operator to its converter.
dispatcher = Dispatcher({"mylib::numpy_sin": numpy_sin_to_onnx})
And we convert again.
onx = to_onnx(model, (x,), dispatcher=dispatcher, optimize=False)
print(pretty_onnx(onx))
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 3]
Sin(x) -> numpy_sin
Identity(numpy_sin) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
And we convert again with optimization this time.
onx = to_onnx(model, (x,), dispatcher=dispatcher, optimize=True)
print(pretty_onnx(onx))
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 3]
Sin(x) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 3]
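The only difference between the two printouts is that the Identity node is gone: Sin now writes directly to output_0. The toy sketch below illustrates that kind of rewrite on a simplified graph. It is not the library's optimizer. The tuple-based node representation and the function name remove_identity are made up for this illustration, and it assumes every Identity input is produced by another node and that Identity nodes are not chained.

```python
from typing import Dict, List, Tuple

# Hypothetical simplified node: (op_type, input names, output names).
Node = Tuple[str, List[str], List[str]]


def remove_identity(nodes: List[Node]) -> List[Node]:
    """Drops Identity nodes and renames their input to their output."""
    renames: Dict[str, str] = {}
    kept: List[Node] = []
    for op_type, inputs, outputs in nodes:
        if op_type == "Identity":
            # Every use of the Identity input becomes the Identity output.
            renames[inputs[0]] = outputs[0]
        else:
            kept.append((op_type, inputs, outputs))

    def rename(n: str) -> str:
        return renames.get(n, n)

    return [
        (op_type, [rename(i) for i in inputs], [rename(o) for o in outputs])
        for op_type, inputs, outputs in kept
    ]


# The unoptimized graph printed earlier, in the simplified form.
unoptimized = [
    ("Sin", ["x"], ["numpy_sin"]),
    ("Identity", ["numpy_sin"], ["output_0"]),
]
optimized = remove_identity(unoptimized)
print(optimized)
```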
Let’s make sure the node was produced by the user-defined converter for numpy_sin. The name should be ‘mylib.numpy_sin’.
print(onx.graph.node[0])
input: "x"
output: "output_0"
name: "mylib.numpy_sin"
op_type: "Sin"
domain: ""
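Looking at node[0] is enough here because the optimized graph has a single node. For a larger graph, one could collect every node carrying the converter's name instead. The sketch below shows one possible way to do that. It relies only on the name and op_type attributes visible in the printout above. nodes_named is a hypothetical helper, and stand-in objects replace the real ONNX nodes so the example stays self-contained.

```python
from types import SimpleNamespace
from typing import Any, Iterable, List


def nodes_named(nodes: Iterable[Any], name: str) -> List[Any]:
    """Returns the nodes whose name attribute equals the given name."""
    return [node for node in nodes if node.name == name]


# Stand-ins for ONNX nodes; the second node and its name are hypothetical.
fake_nodes = [
    SimpleNamespace(name="mylib.numpy_sin", op_type="Sin"),
    SimpleNamespace(name="some_other_node", op_type="Identity"),
]
found = nodes_named(fake_nodes, "mylib.numpy_sin")
print([node.op_type for node in found])
```

With the converted model, the equivalent call would be nodes_named(onx.graph.node, "mylib.numpy_sin").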
And visually (figure not reproduced here).