Torch Export to ONNX

The official converter is torch.onnx.export(). The converter documented here, yobx.torch.interpreter.to_onnx(), has been used to investigate cases where the official one fails. It is designed to fail quickly and offers more options to capture the torch.fx.Graph, such as symbolic tracing or different decomposition tables. The first section describes the differences, and Overview of Exportability Comparison shows some of them on basic examples.

Note

yobx.torch.interpreter.to_onnx() is not torch.onnx.export(). See Not torch.onnx.export for a detailed comparison.

This section describes the design of the PyTorch-to-ONNX conversion pipeline. The entry point is yobx.torch.interpreter.to_onnx(), which accepts a torch.nn.Module and representative inputs and returns an onnx.ModelProto (or an onnx.model_container.ModelContainer for large models).

The pipeline has three main stages:

  1. Export — the module is traced into a portable ExportedProgram (or GraphModule) using one of the strategies provided by ExportOptions (strict, nostrict, tracing, jit, dynamo, fake, …).

  2. Interpret — DynamoInterpreter walks the FX graph node by node and emits the corresponding ONNX operators into a GraphBuilder.

  3. Optimise — the accumulated ONNX graph is folded, simplified, and serialised by to_onnx().

The remaining pages in this section document supporting concerns: how custom pytree nodes must be registered before export (Flattening Functionalities (torch)), which internal torch/transformers patches are needed for successful tracing (Patches (torch export)), and how real forward passes can be used to infer export arguments and dynamic shapes automatically (InputObserver).