Tutorial¶
This module was started to experiment with the function torch.export.export()
and see what kind of issues occur when leveraging it to convert
a torch.nn.Module into ONNX.
The tutorial is a collection of examples and benchmarks around that topic.
Section Design explains how a converter works from a torch model
to the onnx graph. The official exporter is implemented in pytorch
itself through the function torch.onnx.export().
The next sections show many examples, including how to deal with some possible issues.
torch.export.export: export to a Graph¶
All exporters rely on the function torch.export.export()
to convert a pytorch module into a torch.fx.Graph.
Only then does the conversion to ONNX start. Most of the issues come
from this first step, so it is useful to understand what it does.
The pytorch documentation already has many examples about it.
Here are some corner cases.
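Before the corner cases, a minimal sketch of that first step: exporting a small module and printing the captured graph (the module and shapes are arbitrary).

import torch

class TwoLayers(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))

# torch.export.export returns an ExportedProgram,
# its attribute .graph holds the captured torch.fx.Graph
ep = torch.export.export(TwoLayers(), (torch.rand(3, 4),))
print(ep.graph)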
Dynamic Shapes¶
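The examples in this section deal with dimensions that are not fixed at export time. As a minimal sketch (the module is arbitrary), torch.export.export() accepts a dynamic_shapes argument declaring which dimensions may vary, here the batch dimension:

import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return x * 2

# declare the first dimension of x as dynamic
batch = torch.export.Dim("batch")
ep = torch.export.export(
    Model(), (torch.rand(4, 8),), dynamic_shapes={"x": {0: batch}}
)
print(ep.graph)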
strict = ?¶
The exporter relies on torch.export.export(), which exposes a parameter
strict: bool = True (true by default). This parameter usually has no impact
except in some rare configurations where the behaviour differs.
torch.ops.higher_order.scan
torch.ops.higher_order.scan() is a way to export a model containing a loop.
Not all signatures work with this operator.
Here is an example with scan.
<<<
import torch

def add(carry: torch.Tensor, y: torch.Tensor):
    next_carry = carry + y
    return [next_carry, next_carry]

class ScanModel(torch.nn.Module):
    def forward(self, x):
        init = torch.zeros_like(x[0])
        carry, out = torch.ops.higher_order.scan(
            add, [init], [x], dim=0, reverse=False, additional_inputs=[]
        )
        return carry

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
model = ScanModel()
expected = model(x)
print("------")
print(expected, x.sum(axis=0))
print("------ strict=False")
print(torch.export.export(model, (x,), strict=False).graph)
print("------ strict=True")
print(torch.export.export(model, (x,), strict=True).graph)
>>>
------
tensor([12., 15., 18.]) tensor([12., 15., 18.])
------ strict=False
graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem,)
------ strict=True
graph():
    %x : [num_users=2] = placeholder[target=x]
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
    %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem,)
inplace x[…, i] = y
This expression cannot be captured with strict=False.
<<<
import torch

class UpdateModel(torch.nn.Module):
    def forward(
        self, x: torch.Tensor, update: torch.Tensor, kv_index: torch.LongTensor
    ):
        x = x.clone()
        x[..., kv_index] = update
        return x

example_inputs = (
    torch.ones((4, 4, 10)).to(torch.float32),
    (torch.arange(2) + 10).to(torch.float32).reshape((1, 1, 2)),
    torch.Tensor([1, 2]).to(torch.int32),
)
model = UpdateModel()
try:
    torch.export.export(model, example_inputs, strict=False)
except Exception as e:
    print(e)
>>>
torch.onnx.export: export to ONNX¶
These examples rely on torch.onnx.export().
Simple Case¶
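A minimal sketch, assuming a recent pytorch where dynamo=True selects the exporter built on torch.export.export() (the module and file name are arbitrary):

import torch

class Neg(torch.nn.Module):
    def forward(self, x):
        return -x

# returns an ONNXProgram, which can be saved as an onnx file
onnx_program = torch.onnx.export(Neg(), (torch.rand(3, 4),), dynamo=True)
onnx_program.save("neg.onnx")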
Control Flow¶
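A data-dependent test usually needs to be rewritten with torch.cond before it can be captured. A minimal sketch (whether the captured subgraphs then convert to an ONNX If depends on the exporter version):

import torch

class CondModel(torch.nn.Module):
    def forward(self, x):
        def then_branch(x):
            return x * 2

        def else_branch(x):
            return -x

        # both branches are captured as subgraphs of a higher-order op
        return torch.cond(x.sum() > 0, then_branch, else_branch, [x])

ep = torch.export.export(CondModel(), (torch.rand(3),))
print(ep.graph)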
Custom Operators¶
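A custom operator must be registered with pytorch, including a fake implementation used while tracing, before the exporter can map it to ONNX. A minimal sketch with a hypothetical operator mylib::times_two:

import torch

@torch.library.custom_op("mylib::times_two", mutates_args=())
def times_two(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@times_two.register_fake
def _(x):
    # shape and dtype propagation only, no real computation
    return torch.empty_like(x)

class Model(torch.nn.Module):
    def forward(self, x):
        return times_two(x)

ep = torch.export.export(Model(), (torch.rand(3),))
print(ep.graph)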
Submodules¶
Models¶
Optimization¶
Supported Scenarios¶
The following pages explore many kinds of signatures for a forward method and show how they translate into ONNX when they can. The results are summarized in those pages. They cover models taking tensors, lists of tensors, integers, or floats, and also tests and loops.
Frequent Exceptions or Errors with the Exporter¶
Unsupported functions or classes¶
If the conversion to onnx fails, the function bypass_export_some_errors
may help to solve some of the errors. The documentation of this function
gives the list of issues it can bypass.
from experimental_experiment.torch_interpreter.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors():
    # export to onnx with (model, inputs, ...)
    ...
If the inputs contain a cache class, you may need to patch them.
from experimental_experiment.torch_interpreter.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    # export to onnx with (model, inputs, ...)
This function is a work in progress as the exporter extends the list of supported models. A standalone copy of this function can be found at phi35.
torch._dynamo.exc.Unsupported¶
torch._dynamo.exc.Unsupported: call_function BuiltinVariable(NotImplementedError) [ConstantVariable()] {}
This exception started to show up with transformers==4.38.2 but does not seem related to that package. Wrapping the code with the following fixes it.
with torch.no_grad():
    # ...
RuntimeError¶
RuntimeError: Encountered autograd state manager op <built-in function _set_grad_enabled> trying to change global autograd state while exporting.
Wrapping the code in torch.no_grad() probably solves this issue.
with torch.no_grad():
    # ...
Play with onnx models and onnxruntime¶
onnxscript is one way to directly create models or functions in ONNX. The onnxscript Tutorial explains how it works. Some other examples follow.
An exported model can be slow. It can be profiled on CUDA with NVIDIA's native profiling tools. It can also be profiled with the tool implemented in onnxruntime. The next example shows that on CPU.
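A minimal sketch of onnxruntime profiling on CPU, assuming a model saved as model.onnx with a single float input named x:

import numpy as np
import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True  # dump a JSON trace of the time spent in every kernel
sess = ort.InferenceSession("model.onnx", opts, providers=["CPUExecutionProvider"])
sess.run(None, {"x": np.random.rand(3, 4).astype(np.float32)})
print(sess.end_profiling())  # path of the generated JSON trace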
Deeper into pytorch and onnx¶
101¶
102¶
201¶
301¶
to_onnx: another export to investigate¶
to_onnx implements another exporter to ONNX. It does not support all
the cases torch.onnx.export() does.
It fails rather than trying different options to recover.
It calls torch.export.export()
but does not alter the graph
(no rewriting, no decomposition) before converting it to onnx.
It is used to investigate export issues raised by torch.export.export().
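A minimal sketch of its usage, assuming the import path used throughout this package (the function returns an onnx ModelProto):

import torch
from experimental_experiment.torch_interpreter import to_onnx

class Neg(torch.nn.Module):
    def forward(self, x):
        return -x

onx = to_onnx(Neg(), (torch.rand(3, 4),))
with open("neg.onnx", "wb") as f:
    f.write(onx.SerializeToString())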
Simple Case¶
Control Flow¶
Custom Operators¶
Submodules¶
Model¶
Optimization¶
Dockers¶
Old work used to play with torch.compile() inside a docker container.