Tutorial

This module was started to experiment with torch.export.export() and to see what kind of issues occur when leveraging that function to convert a torch.nn.Module into ONNX. The tutorial is a collection of examples and benchmarks around that topic. Section Design explains how a converter works from the torch model to the onnx graph. The official exporter is implemented in pytorch itself through the function torch.onnx.export(). The next sections show many examples, including how to deal with some possible issues.

torch.export.export: export to a Graph

All exporters rely on torch.export.export() to convert a pytorch module into a torch.fx.Graph. Only then does the conversion to ONNX start. Most issues come from this first step, so it is useful to understand what it does. The pytorch documentation already has many examples about it. Here are some corner cases.
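
A minimal sketch of that first step (the model is illustrative):

import torch


class Neg(torch.nn.Module):
    def forward(self, x):
        return -x


ep = torch.export.export(Neg(), (torch.randn((2, 3)),))
# ep is an ExportedProgram, ep.graph is the captured torch.fx.Graph
print(ep.graph)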

Dynamic Shapes
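
Dynamic dimensions are declared with torch.export.Dim and given to torch.export.export() through parameter dynamic_shapes. A minimal sketch (the model and the dimension name are illustrative):

import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return x.sum(axis=1)


x = torch.randn((5, 6))
# the first dimension is declared dynamic, the second one remains static
ep = torch.export.export(
    Model(), (x,), dynamic_shapes={"x": {0: torch.export.Dim("batch")}}
)
print(ep.graph)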

strict = ?

The parameter strict of torch.export.export() usually has no impact, except in some rare cases.

The exporter relies on torch.export.export(), which exposes a parameter strict: bool = True (true by default). The behaviour differs in some specific configurations.

torch.ops.higher_order.scan

torch.ops.higher_order.scan() is a way to export a model containing a loop. Not all signatures work in this mode. Here is an example using scan.

<<<

import torch


def add(carry: torch.Tensor, y: torch.Tensor):
    # combine function: receives the carried state and the current slice of x,
    # returns the new state and the output produced at this iteration
    next_carry = carry + y
    return [next_carry, next_carry]


class ScanModel(torch.nn.Module):
    def forward(self, x):
        # the initial state has the shape of one slice of x
        init = torch.zeros_like(x[0])
        # scan iterates over x along dim=0, carrying the accumulated sum
        carry, out = torch.ops.higher_order.scan(
            add, [init], [x], dim=0, reverse=False, additional_inputs=[]
        )
        return carry


x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)
model = ScanModel()
expected = model(x)
print("------")
print(expected, x.sum(axis=0))
print("------ strict=False")
print(torch.export.export(model, (x,), strict=False).graph)
print("------ strict=True")
print(torch.export.export(model, (x,), strict=True).graph)

>>>

    ------
    tensor([12., 15., 18.]) tensor([12., 15., 18.])
    ------ strict=False
    graph():
        %x : [num_users=2] = placeholder[target=x]
        %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
        %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
        %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
        %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
        %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
        return (getitem,)
    ------ strict=True
    graph():
        %x : [num_users=2] = placeholder[target=x]
        %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%x, 0, 0), kwargs = {})
        %zeros_like : [num_users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%select,), kwargs = {pin_memory: False})
        %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
        %scan : [num_users=2] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%zeros_like], [%x], 0, False, []), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 0), kwargs = {})
        %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
        return (getitem,)

inplace x[…, i] = y

This in-place assignment cannot be captured with strict=False.

<<<

import torch


class UpdateModel(torch.nn.Module):
    def forward(
        self, x: torch.Tensor, update: torch.Tensor, kv_index: torch.LongTensor
    ):
        x = x.clone()
        x[..., kv_index] = update
        return x


example_inputs = (
    torch.ones((4, 4, 10)).to(torch.float32),
    (torch.arange(2) + 10).to(torch.float32).reshape((1, 1, 2)),
    torch.Tensor([1, 2]).to(torch.int32),
)

model = UpdateModel()

try:
    torch.export.export(model, example_inputs, strict=False)
except Exception as e:
    print(e)

>>>

    (the export fails on the in-place assignment; the exact error message depends on the torch version)

torch.onnx.export: export to ONNX

These examples rely on torch.onnx.export().

Simple Case

Linear Regression and export to ONNX
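
A minimal sketch of such an export could be (the file name, the shapes and the axis name are illustrative):

import torch

# a linear regression is a single Linear layer
model = torch.nn.Linear(3, 1)
x = torch.randn((5, 3))

torch.onnx.export(
    model,
    (x,),
    "linear_regression.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch"}},
)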

Control Flow

Custom Operators

  • l-plot-exporter-recipes-onnx-exporter-custom-ops-fct

  • l-plot-exporter-recipes-onnx-exporter-custom-ops-inplace

Submodules

  • l-plot-exporter-recipes-onnx-exporter-modules

Models

Optimization

See Pattern-based Rewrite Using Rules With onnxscript.

Supported Scenarios

The following pages explore many kinds of signatures for a forward method and show how they translate into ONNX when they can. The results are summarized by the following pages. They try models taking tensors, lists of tensors, integers or floats as inputs. They also try tests and loops.

Frequent Exceptions or Errors with the Exporter

Unsupported functions or classes

If the conversion to onnx fails, the function bypass_export_some_errors may help to work around some of the errors. The documentation of this function gives the list of issues it can bypass.

from experimental_experiment.torch_interpreter.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors():
    # export to onnx with (model, inputs, ...)
    ...

If the input contains a cache class, you may need to patch the inputs.

from experimental_experiment.torch_interpreter.onnx_export_errors import (
    bypass_export_some_errors,
)

with bypass_export_some_errors(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    # export to onnx with (model, inputs, ...)

This function is a work in progress as the exporter extends the list of supported models. A standalone copy of this function can be found at phi35.

torch._dynamo.exc.Unsupported

torch._dynamo.exc.Unsupported: call_function BuiltinVariable(NotImplementedError) [ConstantVariable()] {}

This exception started to show up with transformers==4.38.2 but it does not seem to be related to that package. Wrapping the code with the following fixes it.

with torch.no_grad():
    # the code to export goes here
    ...

RuntimeError

RuntimeError: Encountered autograd state manager op <built-in function _set_grad_enabled> trying to change global autograd state while exporting.

Wrapping the code in torch.no_grad() probably solves this issue as well.

with torch.no_grad():
    # the code to export goes here
    ...

Play with onnx models and onnxruntime

onnxscript is one way to directly create a model or a function in ONNX. The onnxscript Tutorial explains how it works. Some other examples follow.
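
A minimal sketch of a function defined with onnxscript (the opset version is illustrative):

from onnxscript import FLOAT, script
from onnxscript import opset18 as op


@script()
def square_add(X: FLOAT[...], Y: FLOAT[...]) -> FLOAT[...]:
    # becomes the onnx operators Mul and Add
    return op.Add(op.Mul(X, X), Y)


# the decorated function can be converted into an onnx model
model_proto = square_add.to_model_proto()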

An exported model can be slow. It can be profiled on CUDA with the native profiling tools NVIDIA provides. It can also be profiled with the profiling tool implemented in onnxruntime. The next example shows how to do that on CPU.
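
A minimal sketch of the onnxruntime profiling, reusing the model exported in the sketch above (the file and input names are assumptions):

import numpy as np
import onnxruntime

opts = onnxruntime.SessionOptions()
opts.enable_profiling = True  # records the execution time of every kernel

sess = onnxruntime.InferenceSession(
    "linear_regression.onnx", opts, providers=["CPUExecutionProvider"]
)
feeds = {"x": np.random.rand(5, 3).astype(np.float32)}
for _ in range(10):
    sess.run(None, feeds)

# writes a JSON trace and returns its name, it can be loaded in chrome://tracing
print(sess.end_profiling())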

Deeper into pytorch and onnx

101

102

201

301

to_onnx: another export to investigate

to_onnx implements another exporter to ONNX. It does not support all the cases torch.onnx.export() does. It fails rather than trying different options to recover. It calls torch.export.export() but does not alter the graph (no rewriting, no decomposition) before converting it to onnx. It is used to investigate export issues raised by torch.export.export().
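
A minimal sketch of a call, assuming to_onnx can be imported as below (the model is illustrative):

import torch
from experimental_experiment.torch_interpreter import to_onnx


class Model(torch.nn.Module):
    def forward(self, x):
        return x.abs() + 1


# returns an onnx ModelProto built from the exported graph
onx = to_onnx(Model(), (torch.randn((4, 5)),))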

Simple Case

Control Flow

Custom Operators

Submodules

Model

Optimization

Dockers

Older work used to experiment with torch.compile() inside a docker container.