Torch Export to ONNX#

Note

yobx.torch.interpreter.to_onnx() is not torch.onnx.export(). See Not torch.onnx.export for a detailed comparison.

This section describes the design of the PyTorch-to-ONNX conversion pipeline. The entry point is yobx.torch.interpreter.to_onnx(), which accepts a torch.nn.Module and representative inputs and returns an onnx.ModelProto (or an onnx.model_container.ModelContainer for large models).

The pipeline has three main stages:

  1. Export — the module is traced into a portable ExportedProgram (or GraphModule) using one of the strategies provided by ExportOptions (strict, nostrict, tracing, jit, dynamo, fake, …).

  2. Interpret — DynamoInterpreter walks the FX graph node by node and emits the corresponding ONNX operators into a GraphBuilder.

  3. Optimise — the accumulated ONNX graph is folded, simplified, and serialised by to_onnx().

The remaining pages in this section document supporting concerns: how custom pytree nodes must be registered before export (Flattening Functionalities (torch)), which internal torch/transformers patches are needed for successful tracing (Patches (torch export)), and how real forward passes can be used to infer export arguments and dynamic shapes automatically (InputObserver).
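The idea behind inferring dynamic shapes from real forward passes can be sketched with a forward pre-hook. This is a hypothetical, simplified illustration (the function name `observe_shapes` and its heuristic are not part of the library's InputObserver API): axes whose sizes differ across observed calls are flagged as candidates for dynamic dimensions.

```python
import torch


def observe_shapes(model, *example_batches):
    """Run real forward passes, record the first input's shape each time,
    and flag axes that vary across calls as candidate dynamic dimensions."""
    seen = []
    hook = model.register_forward_pre_hook(
        lambda mod, args: seen.append(tuple(args[0].shape))
    )
    try:
        for batch in example_batches:
            model(batch)
    finally:
        hook.remove()
    # An axis is dynamic if its size differs between observed calls.
    dynamic_axes = [
        i for i in range(len(seen[0])) if len({s[i] for s in seen}) > 1
    ]
    return seen, dynamic_axes


model = torch.nn.Linear(4, 2)
shapes, dyn = observe_shapes(model, torch.randn(3, 4), torch.randn(5, 4))
print(shapes, dyn)  # axis 0 varies across calls: a dynamic batch dimension
```

A real observer would additionally track nested inputs (dicts, lists, keyword arguments) so that export arguments can be reconstructed automatically, but the core mechanism — watching concrete shapes over several calls — is the same.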