Torch Converter to ONNX#

Note

This section covers the core conversion pipeline that transforms a torch.nn.Module into an onnx.ModelProto. It is only relevant when exporting PyTorch models and has no bearing on ONNX models built directly with the builder APIs.

Not torch.onnx.export#

yobx.torch.interpreter.to_onnx() is not torch.onnx.export(). They share a similar high-level goal — converting a torch.nn.Module to an onnx.ModelProto — but are otherwise independent implementations with different design priorities.

yobx vs torch.onnx.export#

| Aspect | yobx.torch.interpreter.to_onnx() | torch.onnx.export() (dynamo backend) |
|---|---|---|
| ATen → ONNX translation | Custom DynamoInterpreter that walks the FX graph node-by-node and emits ONNX ops directly into a GraphBuilder | Uses onnxscript as the intermediate representation for ATen-to-ONNX mappings |
| Export strategy | Multiple strategies selectable via ExportOptions: strict, nostrict, tracing, fake, … The tracing strategy uses CustomTracer (symbolic tracing) and can handle models where torch.export.export fails. | Dynamo-based (torch.export.export or torch._dynamo.export) |
| Graph decomposition | The graph is not decomposed by default (unless decomposition_table is set in ExportOptions). Each ATen op is translated directly, which keeps the graph compact but means that a new op introduced by PyTorch will raise a NotImplementedError until a dedicated converter is added. | Always decomposes the graph into a fixed set of core ATen ops, making the exporter robust to new high-level ops at the cost of larger, less readable ONNX graphs |
| Pattern optimization | Supports any kind of pattern matching, multiple outputs, variable number of nodes, outputs, types. | Supports a fixed number of outputs. |
| Shape & type propagation | Every intermediate result is typed and (when possible) given a concrete or symbolic shape inside the builder, enabling richer optimization and easier debugging | Shape information comes from the traced graph; no additional in-builder propagation step |
| Large-model support | large_model=True returns an onnx.model_container.ModelContainer; weights can be stored as external data files via external_threshold | External data handled by the caller after export |
| Debugging | Rich set of environment variables (ONNXSTOP, ONNXSTOPSHAPE, ONNXSTOPTYPE, ONNXSTOPOUTPUT, …) let you pinpoint exactly which node assigns a suspicious shape or type; see Debugging when Exporting with GraphBuilder and Debugging | Standard Python / PyTorch debugging tools |

In short: if you need the officially supported PyTorch exporter, use torch.onnx.export(). If you need finer control over the ATen → ONNX translation, built-in graph optimization, advanced export strategies such as fake-tensor mode or automatic dynamic-shape inference, or rich debugging capabilities (Debugging when Exporting with GraphBuilder, Debugging), yobx.torch.interpreter.to_onnx() may be a better fit.

The entry point for converting a PyTorch model to ONNX is yobx.torch.interpreter.to_onnx(). The function orchestrates a multi-stage pipeline:

  1. Export — trace the module into a portable graph representation using torch.export.export() (or one of its alternatives).

  2. Interpret — walk every node of the FX graph and translate it into a sequence of ONNX operations via DynamoInterpreter.

  3. Optimise — run the ONNX graph through the optimiser shipped in GraphBuilder to fold constants, remove redundant casts, and simplify shapes.

Pipeline overview#

torch.nn.Module
      │
      │  ExportOptions  (strict / nostrict / tracing / jit / dynamo …)
      ▼
torch.export.ExportedProgram   ←─── torch.fx.GraphModule (optional)
      │
      │  _make_builder_interpreter()
      ▼
GraphBuilder  +  DynamoInterpreter
      │
      │  DynamoInterpreter.run()  ── node-by-node dispatch
      │        ├── placeholder   →  graph input
      │        ├── get_attr      →  initializer
      │        ├── call_function →  aten_* converter
      │        ├── call_method   →  aten method converter
      │        ├── call_module   →  submodule (recursive)
      │        └── output        →  graph output
      ▼
GraphBuilder (ONNX ops accumulated)
      │
      │  optimize=True  →  OptimizationOptions applied
      ▼
onnx.ModelProto  (or ModelContainer for large models)

Key components#

to_onnx#

yobx.torch.interpreter.to_onnx() is the public API. Its most important parameters are:

  • mod — the torch.nn.Module to export.

  • args / kwargs — representative inputs (used to infer shapes and to validate the export when validate_onnx is set).

  • dynamic_shapes — a nested structure that marks which tensor dimensions are symbolic; mirrors the argument accepted by torch.export.export().

  • export_options — an ExportOptions instance (or a short strategy string such as "nostrict-dec") that controls how the model is first exported to an FX graph.

  • target_opset — the ONNX opset version to target (default: the latest supported opset minus one).

  • optimize — when True (default) the generated ONNX graph is optimised before being returned.

  • dispatcher — an optional Dispatcher that can override the default ATen-to-ONNX mapping for specific operators.

Basic usage#

<<<

import torch
from yobx.torch.interpreter import to_onnx


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = Neuron(5, 3)
x = torch.rand(2, 5)
onx = to_onnx(model, (x,))
print(onx.graph.node[:3])

>>>

    [input: "x"
    input: "GemmTransposePattern--p_linear_weight::T10"
    input: "linear.bias"
    output: "linear"
    name: "GemmTransposePattern--MatMulAddPattern--Opset2"
    op_type: "Gemm"
    attribute {
      name: "transB"
      i: 1
      type: INT
    }
    doc_string: "#Io1.\n-T1:2x3\n#Io1#Io1"
    domain: ""
    metadata_props {
      key: "scope"
      value: "linear"
    }
    metadata_props {
      key: "module[0]"
      value: "sphinx_runpython.runpython.sphinx_runpython_extension.run_python_script_136119210221504.<locals>.Neuron"
    }
    metadata_props {
      key: "module[1]"
      value: "torch.nn.modules.linear.Linear"
    }
    metadata_props {
      key: "source[0]"
      value: "linear.forward"
    }
    metadata_props {
      key: "intypes"
      value: "FLOAT / FLOAT / FLOAT"
    }
    metadata_props {
      key: "outtypes"
      value: "FLOAT"
    }
    metadata_props {
      key: "inshapes"
      value: "(2, 5) / (3, 5) / (3,)"
    }
    metadata_props {
      key: "outshapes"
      value: "(2, 3)"
    }
    , input: "linear"
    output: "output_0"
    name: "relu"
    op_type: "Relu"
    doc_string: "#Io1.\n-T1:2x3\nrelu:1:(2,3)"
    domain: ""
    metadata_props {
      key: "module[0]"
      value: "sphinx_runpython.runpython.sphinx_runpython_extension.run_python_script_136119210221504.<locals>.Neuron"
    }
    metadata_props {
      key: "intypes"
      value: "FLOAT"
    }
    metadata_props {
      key: "outtypes"
      value: "FLOAT"
    }
    metadata_props {
      key: "inshapes"
      value: "(2, 3)"
    }
    metadata_props {
      key: "outshapes"
      value: "(2, 3)"
    }
    ]

ExportOptions#

ExportOptions encapsulates the strategy used to obtain the FX graph from the module. Several named strategies are available:

| Strategy string | Meaning |
|---|---|
| "strict" | torch.export.export with strict=True |
| "nostrict" | torch.export.export with strict=False (default) |
| "nostrict-dec" | strict=False + default decomposition table |
| "nostrict-decall" | strict=False + full decomposition table |
| "tracing" | symbolic tracing via CustomTracer |
| "jit" | JIT script → FX graph |
| "dynamo" | torch._dynamo.export |
| "dec" | default decomposition table, default strict setting |
| "decall" | full decomposition table, default strict setting |
| "fake" | use FakeTensor inputs instead of real tensors |
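One way to read the strategy strings is as shorthand for a small set of independent settings. The sketch below is a hypothetical lookup table that mirrors the rows listed above. The names STRATEGIES and describe_strategy, and the setting keys on the right-hand side, are invented for illustration and are not the library's own parser or option names.

```python
from typing import Any, Dict

# Hypothetical lookup table mirroring the documented strategy strings.
# The setting names on the right are assumptions made for this sketch.
STRATEGIES: Dict[str, Dict[str, Any]] = {
    "strict": {"strict": True},
    "nostrict": {"strict": False},
    "nostrict-dec": {"strict": False, "decomposition": "default"},
    "nostrict-decall": {"strict": False, "decomposition": "full"},
    "tracing": {"tracing": True},
    "jit": {"jit": True},
    "dynamo": {"dynamo": True},
    "dec": {"decomposition": "default"},
    "decall": {"decomposition": "full"},
    "fake": {"fake": True},
}


def describe_strategy(name: str) -> Dict[str, Any]:
    """Return a copy of the settings implied by a strategy string."""
    if name not in STRATEGIES:
        raise ValueError(f"unknown strategy {name!r}, expected one of {sorted(STRATEGIES)}")
    return dict(STRATEGIES[name])


print(describe_strategy("nostrict-dec"))
print(describe_strategy("dec"))  # no "strict" key: the default strict setting applies
```

In this reading, a missing key means "keep the default", which is how the rows "dec" and "decall" differ from their "nostrict-" counterparts.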

Fake tensors#

When fake=True (or strategy="fake"), the export stage replaces every real input tensor with a FakeTensor — a lightweight stand-in that carries dtype, shape, and device metadata but holds no actual data. This is useful when loading model weights into memory just to trace the graph would be prohibitively expensive (e.g. very large language models).

The conversion from real tensors to fake tensors is handled by make_fake_with_dynamic_dimensions(), which also ensures that dimensions sharing the same name in dynamic_shapes are mapped to the same symbolic integer. The FakeTensorContext manages the underlying FakeTensorMode and the mapping between concrete dimension values and their symbolic counterparts.
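The two properties described above (metadata only, and one shared symbol per dimension name) can be pictured with a toy model in plain Python. FakeMeta and make_fake are hypothetical names introduced for this sketch, and symbols are represented by plain strings. The actual implementation relies on PyTorch's FakeTensorMode and symbolic integers instead.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union

Dim = Union[int, str]  # a concrete size or the name of a symbolic dimension


@dataclass(frozen=True)
class FakeMeta:
    """Metadata-only stand-in for a tensor: no data buffer is stored."""
    dtype: str
    shape: Tuple[Dim, ...]
    device: str = "cpu"


def make_fake(
    shapes: Dict[str, Tuple[int, ...]],
    dynamic_shapes: Dict[str, Dict[int, str]],
    dtype: str = "float32",
) -> Dict[str, FakeMeta]:
    """Replace every dynamic axis by its name so that equal names share one symbol."""
    seen: Dict[str, int] = {}  # symbol name -> concrete value observed first
    fakes: Dict[str, FakeMeta] = {}
    for name, shape in shapes.items():
        new_shape = list(shape)
        for axis, symbol in dynamic_shapes.get(name, {}).items():
            # In this sketch, inputs sharing a name must agree on the concrete value.
            if seen.setdefault(symbol, shape[axis]) != shape[axis]:
                raise ValueError(f"dimension {symbol!r} has conflicting values")
            new_shape[axis] = symbol
        fakes[name] = FakeMeta(dtype, tuple(new_shape))
    return fakes


fakes = make_fake(
    {"input_ids": (2, 7), "attention_mask": (2, 7)},
    {"input_ids": {0: "batch", 1: "seq"}, "attention_mask": {0: "batch", 1: "seq"}},
)
print(fakes["input_ids"].shape, fakes["attention_mask"].shape)
```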

DynamoInterpreter#

DynamoInterpreter is the heart of the converter. It walks the torch.fx.Graph node by node and translates each node into one or more ONNX operators appended to the GraphBuilder.

Node kinds and their handlers#

| FX node kind | Action taken by the interpreter |
|---|---|
| placeholder | Registers a graph input with the correct ONNX type and shape. |
| get_attr | Looks up a weight / buffer / constant via the retriever callable and registers it as an ONNX initializer. |
| call_function | Looks up the ATen function in the aten-function registry (_aten_functions.py) or delegates to the Dispatcher. |
| call_method | Looks up the ATen method in the method registry (_aten_methods.py). |
| call_module | Recursively converts a submodule using a nested DynamoInterpreter instance. |
| output | Registers graph outputs, applying any output masks produced by the export stage. |
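The dispatch on node kind can be sketched with a few lines of plain Python. This is a toy model under simplifying assumptions: ToyNode, ToyGraph and run are hypothetical names, the call_method and call_module kinds are left out, and targets are already ONNX operator names rather than ATen overloads.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class ToyNode:
    """Simplified stand-in for an FX node: a kind, a result name and a target."""
    op: str
    name: str
    target: Any = None
    args: Tuple[str, ...] = ()


@dataclass
class ToyGraph:
    """Collects what a builder would store as inputs, initializers, nodes and outputs."""
    inputs: List[str] = field(default_factory=list)
    initializers: List[str] = field(default_factory=list)
    nodes: List[Tuple[Any, Tuple[str, ...], str]] = field(default_factory=list)
    outputs: List[str] = field(default_factory=list)


def run(nodes: List[ToyNode]) -> ToyGraph:
    """Walk the nodes in order and dispatch each one on its kind."""
    g = ToyGraph()
    handlers: Dict[str, Callable[[ToyNode], None]] = {
        "placeholder": lambda n: g.inputs.append(n.name),
        "get_attr": lambda n: g.initializers.append(n.name),
        "call_function": lambda n: g.nodes.append((n.target, n.args, n.name)),
        "output": lambda n: g.outputs.extend(n.args),
    }
    for node in nodes:
        if node.op not in handlers:
            raise NotImplementedError(f"no handler for node kind {node.op!r}")
        handlers[node.op](node)
    return g


graph = run([
    ToyNode("placeholder", "x"),
    ToyNode("get_attr", "weight"),
    ToyNode("call_function", "mm", target="MatMul", args=("x", "weight")),
    ToyNode("call_function", "relu", target="Relu", args=("mm",)),
    ToyNode("output", "output", args=("relu",)),
])
print(graph.inputs, graph.initializers, [n[0] for n in graph.nodes], graph.outputs)
```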

ATen-to-ONNX converters#

Each call_function node carries a torch._ops.OpOverload such as aten.relu.default or aten.mm.default. The interpreter resolves this to a Python function whose name follows the pattern aten_<op>_<overload> (e.g. aten_relu_default), defined in one of the converter registry modules (for example _aten_functions.py).

Each converter function has the signature:

def aten_<op>_<overload>(
    g: GraphBuilder,
    sts: Dict[str, Any],   # shape/type state
    outputs: List[str],    # desired output names
    *args,
    **kwargs,
) -> str:
    ...

It appends one or more ONNX nodes to g and returns the name of the primary output tensor.
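The naming rule and the converter contract can be illustrated with a self-contained sketch. Everything here is hypothetical: converter_name, REGISTRY, convert and toy_relu are invented for illustration, and a plain list stands in for the GraphBuilder.

```python
from typing import Any, Callable, Dict, List


def converter_name(overload: str) -> str:
    """Map a qualified overload name such as 'aten.relu.default' to 'aten_relu_default'."""
    namespace, op, variant = overload.split(".")
    return f"{namespace}_{op}_{variant}"


def toy_relu(g: List[Any], sts: Any, outputs: List[str], x: str) -> str:
    """Toy converter: records one Relu node in g and returns the primary output name."""
    g.append(("Relu", [x], outputs))
    return outputs[0]


# Hypothetical registry keyed by converter name.
REGISTRY: Dict[str, Callable[..., str]] = {"aten_relu_default": toy_relu}


def convert(g: List[Any], overload: str, outputs: List[str], *args: str) -> str:
    """Resolve the converter by name and call it with the converter signature."""
    name = converter_name(overload)
    if name not in REGISTRY:
        # Mirrors the documented behaviour for ops without a dedicated converter.
        raise NotImplementedError(f"no converter named {name!r} for {overload!r}")
    return REGISTRY[name](g, None, outputs, *args)


nodes: List[Any] = []
result = convert(nodes, "aten.relu.default", ["y"], "x")
print(result, nodes)
```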

Dispatcher#

Dispatcher allows callers to override or extend the built-in ATen converter mapping without modifying the library. It is especially useful when a model uses custom ops or when the default conversion of a particular op should be replaced.

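from yobx.torch.interpreter import Dispatcher, to_onnx
from yobx.xbuilder import GraphBuilder


def my_relu(g: GraphBuilder, sts, outputs, x):
    # Custom converter: emit a single Relu node for the requested outputs.
    return g.op.Relu(x, outputs=outputs)


# The key follows the aten_<op>_<overload> naming pattern of the built-in converters.
dispatcher = Dispatcher({"aten_relu_default": my_relu})
# model and x are the Neuron instance and the input tensor from the basic usage example.
onx = to_onnx(model, (x,), dispatcher=dispatcher)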

ForceDispatcher is a stricter variant that raises an error when the requested function is not found, making it easier to discover missing converters during development.
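The difference between the two variants comes down to what happens when no override is registered. The following toy model shows that contrast. ToyDispatcher, ToyForceDispatcher, resolve and BUILTIN are hypothetical names, and the exception type raised by the strict variant is an assumption of this sketch.

```python
from typing import Callable, Dict, Optional


class ToyDispatcher:
    """Returns the user override when there is one, None otherwise."""

    def __init__(self, overrides: Dict[str, Callable[[], str]]):
        self.overrides = overrides

    def find(self, name: str) -> Optional[Callable[[], str]]:
        return self.overrides.get(name)


class ToyForceDispatcher(ToyDispatcher):
    """Strict variant: a missing override is an error instead of a fallback."""

    def find(self, name: str) -> Callable[[], str]:
        fct = self.overrides.get(name)
        if fct is None:
            raise RuntimeError(f"no converter registered for {name!r}")
        return fct


# Stand-ins for the built-in converter table.
BUILTIN: Dict[str, Callable[[], str]] = {
    "aten_relu_default": lambda: "built-in relu",
    "aten_mm_default": lambda: "built-in mm",
}


def resolve(name: str, dispatcher: ToyDispatcher) -> Callable[[], str]:
    """Prefer the dispatcher, fall back to the built-in table otherwise."""
    fct = dispatcher.find(name)
    return fct if fct is not None else BUILTIN[name]


overrides = {"aten_relu_default": lambda: "custom relu"}
lenient = ToyDispatcher(overrides)
strict = ToyForceDispatcher(overrides)

print(resolve("aten_relu_default", lenient)())
print(resolve("aten_mm_default", lenient)())
try:
    resolve("aten_mm_default", strict)
except RuntimeError as exc:
    print("strict dispatcher:", exc)
```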

GraphBuilder#

The GraphBuilder accumulates ONNX ops and is responsible for:

  • Type and shape propagation — every result tensor is given an ONNX type and, when possible, a concrete or symbolic shape.

  • Optimisation — constant folding, cast elimination, identity removal, and other peephole passes controlled by OptimizationOptions.

  • Serialisation — once all nodes have been appended, to_onnx() serialises the accumulated state into an onnx.ModelProto.
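To make the shape propagation step concrete, here is a minimal sketch that propagates shapes through the Gemm and Relu nodes of the basic usage example, with the batch dimension made symbolic. gemm_shape and relu_shape are hypothetical helpers written for this illustration and do not reflect how GraphBuilder stores shapes internally.

```python
from typing import Dict, Tuple, Union

Dim = Union[int, str]  # concrete size or symbolic dimension name
Shape = Tuple[Dim, ...]


def gemm_shape(a: Shape, b: Shape, trans_b: bool = False) -> Shape:
    """Output shape of Gemm for 2D inputs, checking the contracted dimension."""
    rows, inner = a
    b_inner, cols = (b[1], b[0]) if trans_b else (b[0], b[1])
    if inner != b_inner:
        raise ValueError(f"incompatible shapes {a} and {b}")
    return (rows, cols)


def relu_shape(x: Shape) -> Shape:
    """Element-wise operators keep the shape of their input."""
    return x


# Known shapes for inputs and initializers; results are filled in node by node.
shapes: Dict[str, Shape] = {"x": ("batch", 5), "weight": (3, 5)}
shapes["linear"] = gemm_shape(shapes["x"], shapes["weight"], trans_b=True)
shapes["output_0"] = relu_shape(shapes["linear"])
print(shapes["output_0"])
```

The symbolic name travels through both nodes unchanged, which is what allows later optimisation passes to reason about shapes without concrete inputs.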

Large models#

When large_model=True the converter returns an onnx.model_container.ModelContainer instead of an onnx.ModelProto. This defers the decision on whether to embed the weights inside the protobuf or to store them as external data files.

The external_threshold parameter (default: 1024 bytes) controls which initializers are treated as external when the container is later saved.
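The effect of the threshold can be sketched as a simple partition of initializers by size. split_initializers is a hypothetical helper, and the comparison used here (a tensor is external when its size is greater than or equal to the threshold) is an assumption borrowed from the size_threshold convention of onnx's external-data helpers, not a statement about the container's exact rule.

```python
from typing import Dict, List, Tuple


def split_initializers(
    sizes_in_bytes: Dict[str, int], external_threshold: int = 1024
) -> Tuple[List[str], List[str]]:
    """Return (embedded, external) initializer names for a given threshold."""
    embedded: List[str] = []
    external: List[str] = []
    for name, size in sizes_in_bytes.items():
        # Assumption: a tensor whose size reaches the threshold goes to an external file.
        (external if size >= external_threshold else embedded).append(name)
    return embedded, external


# Sizes of float32 tensors: 4 bytes per element.
sizes = {"linear.weight": 4096 * 4096 * 4, "linear.bias": 4096 * 4, "eps": 4}
embedded, external = split_initializers(sizes)
print("embedded:", embedded)
print("external:", external)
```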

Dynamic shapes#

The dynamic_shapes argument is forwarded to torch.export.export() and follows the same nested-dict / nested-tuple convention. Symbolic dimension variables should be instances of torch.export.Dim.

use_dyn_not_str() is a convenience helper that replaces string-valued dimension annotations (which some helpers return) with torch.export.Dim.DYNAMIC:

from yobx.torch import use_dyn_not_str
dynamic_shapes = use_dyn_not_str({"x": {0: "batch", 1: "seq"}})
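The kind of replacement this helper performs can be pictured as a recursive walk over the nested structure. The sketch below is a stand-alone approximation: replace_str_dims is a hypothetical function, and the module-level DYNAMIC object stands in for torch.export.Dim.DYNAMIC so that the example runs without PyTorch.

```python
from typing import Any

DYNAMIC = object()  # stand-in for torch.export.Dim.DYNAMIC in this sketch


def replace_str_dims(ds: Any) -> Any:
    """Recursively replace string dimension annotations by the DYNAMIC marker."""
    if isinstance(ds, str):
        return DYNAMIC
    if isinstance(ds, dict):
        # Keys (argument names or axis indices) are kept, only values are visited.
        return {k: replace_str_dims(v) for k, v in ds.items()}
    if isinstance(ds, (list, tuple)):
        return type(ds)(replace_str_dims(v) for v in ds)
    return ds


converted = replace_str_dims({"x": {0: "batch", 1: "seq"}, "y": ({0: "batch"}, None)})
print(converted["x"][0] is DYNAMIC, converted["y"][1])
```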

For more complex scenarios — especially LLMs with prefill/decode phases — InputObserver can infer the dynamic_shapes automatically from real forward passes; see InputObserver.

Submodules as local ONNX functions#

When export_modules_as_functions=True (or a set of module types), the converter unfolds the model via torch.export.unflatten() so that each submodule becomes a separate ONNX local function. This preserves the module hierarchy in the ONNX graph and can reduce model size when the same submodule type is reused many times.

The granularity is controlled by the function_options parameter (a FunctionOptions instance) which specifies, among other things, how initializers inside local functions should be represented (as constants inlined into the function body or as additional function inputs).

Validation#

Setting validate_onnx=True (or a float tolerance) causes to_onnx() to run the exported ONNX model with the same inputs that were used to export it and compare the outputs against the PyTorch model’s outputs. An AssertionError is raised if the maximum absolute difference exceeds the tolerance (default 1e-5).
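The comparison itself is easy to state in code. The following sketch applies the same criterion (maximum absolute difference against a tolerance) to flat lists of floats. max_abs_diff and check_outputs are hypothetical helpers, and the real check operates on tensors rather than Python lists.

```python
from typing import Sequence


def max_abs_diff(expected: Sequence[float], got: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two flat outputs."""
    if len(expected) != len(got):
        raise AssertionError(f"length mismatch: {len(expected)} != {len(got)}")
    return max((abs(e - g) for e, g in zip(expected, got)), default=0.0)


def check_outputs(
    expected: Sequence[float], got: Sequence[float], atol: float = 1e-5
) -> float:
    """Raise AssertionError when the discrepancy exceeds the tolerance."""
    diff = max_abs_diff(expected, got)
    if diff > atol:
        raise AssertionError(f"max abs diff {diff} exceeds tolerance {atol}")
    return diff


torch_out = [0.0, 0.25, 1.5]      # values standing in for the PyTorch outputs
onnx_out = [0.0, 0.250001, 1.5]   # values standing in for the ONNX outputs
print("within tolerance, diff =", check_outputs(torch_out, onnx_out))
```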

Environment variables#

Several environment variables alter the converter’s behaviour without requiring code changes:

| Variable | Effect |
|---|---|
| ONNXVERBOSE=1 | Increases verbosity inside to_onnx(). |
| PRINT_GRAPH_MODULE=1 | Prints the FX graph before interpretation. |
| ONNX_BUILDER_PROGRESS=1 | Shows a progress bar for large models. |
| PRINT_EXPORTED_PROGRAM=1 | Prints the ExportedProgram before interpretation. |
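These variables can also be set from Python for the duration of a single export. The helper below is a generic sketch using only the standard library. temporary_env is a hypothetical name, and the approach assumes the variables are read when the export runs rather than when the package is imported.

```python
import os
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def temporary_env(variables: Dict[str, str]) -> Iterator[None]:
    """Set environment variables inside a with-block, then restore the previous state."""
    saved = {name: os.environ.get(name) for name in variables}
    os.environ.update(variables)
    try:
        yield
    finally:
        for name, value in saved.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value


with temporary_env({"ONNXVERBOSE": "1", "PRINT_GRAPH_MODULE": "1"}):
    # The export call would go here, for example to_onnx(model, (x,)).
    inside = os.environ["ONNXVERBOSE"]
print("value seen during the export:", inside)
```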

Debugging when Exporting with GraphBuilder#

GraphBuilder reads several environment variables at construction time that raise an exception as soon as a named result is assigned a shape, type, or value. Setting one of these is the fastest way to get a Python traceback pointing at the exact line that produces a suspicious tensor.

| Variable | Effect |
|---|---|
| ONNXSTOP=<name> | Raises when result <name> is created (type or shape assignment). Example: ONNXSTOP=attn_output python script.py |
| ONNXSTOPSHAPE=<name> | Raises when result <name> receives a shape. |
| ONNXSTOPTYPE=<name> | Raises when result <name> receives a type. |
| ONNXSTOPSEQUENCE=<name> | Raises when result <name> is assigned a sequence type. |
| ONNXSTOPVALUESHAPE=<name> | Enables extra logging in the shape-value computation path for <name>. |
| ONNXSTOPOUTPUT=<name> | Raises when a node whose output contains <name> is appended to the graph. |
| ONNXDYNDIM=<name> | Raises when dynamic dimension <name> is referenced. |
| ONNXCST=1 | Logs every constant that is evaluated during shape inference. |
| ONNXSHAPECOMPUTE=1 | Raises when a shape cannot be inferred (instead of silently leaving it unknown). |
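The mechanism behind these variables (read the name once at construction, raise at the moment of assignment) can be reproduced in a toy builder. ToyBuilder is a hypothetical class written for this illustration, only ONNXSTOPSHAPE is modelled, and the exception type is an assumption of the sketch.

```python
import os
from typing import Dict, Optional, Tuple


class ToyBuilder:
    """Minimal builder that mimics the 'stop on a named result' idea."""

    def __init__(self) -> None:
        # Read once, at construction time, as described above.
        self._stop_shape: Optional[str] = os.environ.get("ONNXSTOPSHAPE")
        self.shapes: Dict[str, Tuple[int, ...]] = {}

    def set_shape(self, name: str, shape: Tuple[int, ...]) -> None:
        if name == self._stop_shape:
            # The traceback of this exception leads back to the caller that set the shape.
            raise RuntimeError(f"ONNXSTOPSHAPE: result {name!r} receives shape {shape}")
        self.shapes[name] = shape


os.environ["ONNXSTOPSHAPE"] = "attn_output"
builder = ToyBuilder()
builder.set_shape("query", (2, 8, 64))
try:
    builder.set_shape("attn_output", (2, 8, 64))
except RuntimeError as exc:
    print(exc)
finally:
    del os.environ["ONNXSTOPSHAPE"]  # leave the environment as it was found
```

Because the variable is read in the constructor, it has to be set before the builder is created. Changing it afterwards has no effect on an existing instance in this sketch.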

See also#