Torch Converter to ONNX#
Note
This section covers the core conversion pipeline that transforms a
torch.nn.Module into an onnx.ModelProto. It is only
relevant when exporting PyTorch models and has no bearing on ONNX models
built directly with the builder APIs.
Not torch.onnx.export#
yobx.torch.interpreter.to_onnx() is not torch.onnx.export().
They share a similar high-level goal — converting a torch.nn.Module to
an onnx.ModelProto — but are otherwise independent implementations with
different design priorities.
| Aspect | `to_onnx()` | `torch.onnx.export()` |
|---|---|---|
| ATen → ONNX translation | Custom per-op converter functions | Uses onnxscript as the intermediate representation for ATen-to-ONNX mappings |
| Export strategy | Multiple strategies selectable via `ExportOptions` | Dynamo-based (`torch.export`) |
| Graph decomposition | The graph is not decomposed by default (unless a decomposition strategy such as `"nostrict-dec"` is selected) | Always decomposes the graph into a fixed set of core ATen ops, making the exporter robust to new high-level ops at the cost of larger, less readable ONNX graphs |
| Pattern optimization | Pattern matching supports multiple outputs and a variable number of nodes, outputs, and types | Pattern matching supports a fixed number of outputs |
| Shape & type propagation | Every intermediate result is typed and (when possible) given a concrete or symbolic shape inside the builder, enabling richer optimization and easier debugging | Shape information comes from the traced graph; no additional in-builder propagation step |
| Large-model support | `large_model=True` returns a `ModelContainer`, deferring the external-data decision | External data handled by the caller after export |
| Debugging | Rich set of environment variables (see the environment-variable sections below) | Standard Python / PyTorch debugging tools |
In short: if you need the officially supported PyTorch exporter, use
torch.onnx.export(). If you need finer control over the
ATen → ONNX translation, built-in graph optimization, advanced export
strategies such as fake-tensor mode or automatic dynamic-shape inference,
or rich debugging capabilities (Debugging when Exporting with GraphBuilder,
Debugging),
yobx.torch.interpreter.to_onnx() may be a better fit.
The entry point for converting a PyTorch model to ONNX is
yobx.torch.interpreter.to_onnx(). The function orchestrates a
multi-stage pipeline:
1. Export — trace the module into a portable graph representation using `torch.export.export()` (or one of its alternatives).
2. Interpret — walk every node of the FX graph and translate it into a sequence of ONNX operations via `DynamoInterpreter`.
3. Optimise — run the ONNX graph through the optimiser shipped in `GraphBuilder` to fold constants, remove redundant casts, and simplify shapes.
Pipeline overview#
torch.nn.Module
      │
      │  ExportOptions (strict / nostrict / tracing / jit / dynamo …)
      ▼
torch.export.ExportedProgram  ←─── torch.fx.GraphModule (optional)
      │
      │  _make_builder_interpreter()
      ▼
GraphBuilder + DynamoInterpreter
      │
      │  DynamoInterpreter.run() ── node-by-node dispatch
      │      ├── placeholder    → graph input
      │      ├── get_attr       → initializer
      │      ├── call_function  → aten_* converter
      │      ├── call_method    → aten method converter
      │      ├── call_module    → submodule (recursive)
      │      └── output         → graph output
      ▼
GraphBuilder (ONNX ops accumulated)
      │
      │  optimize=True → OptimizationOptions applied
      ▼
onnx.ModelProto  (or ModelContainer for large models)
Key components#
to_onnx#
yobx.torch.interpreter.to_onnx() is the public API. Its most important
parameters are:
- `mod` — the `torch.nn.Module` to export.
- `args` / `kwargs` — representative inputs (used to infer shapes and to validate the export when `validate_onnx` is set).
- `dynamic_shapes` — a nested structure that marks which tensor dimensions are symbolic; mirrors the argument accepted by `torch.export.export()`.
- `export_options` — an `ExportOptions` instance (or a short strategy string such as `"nostrict-dec"`) that controls how the model is first exported to an FX graph.
- `target_opset` — the ONNX opset version to target (default: the latest supported opset minus one).
- `optimize` — when `True` (default) the generated ONNX graph is optimised before being returned.
- `dispatcher` — an optional `Dispatcher` that can override the default ATen-to-ONNX mapping for specific operators.
Basic usage#
<<<
import torch
from yobx.torch.interpreter import to_onnx


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = Neuron(5, 3)
x = torch.rand(2, 5)
onx = to_onnx(model, (x,))
print(onx.graph.node[:3])
>>>
[input: "x"
input: "GemmTransposePattern--p_linear_weight::T10"
input: "linear.bias"
output: "linear"
name: "GemmTransposePattern--MatMulAddPattern--Opset2"
op_type: "Gemm"
attribute {
name: "transB"
i: 1
type: INT
}
doc_string: "#Io1.\n-T1:2x3\n#Io1#Io1"
domain: ""
metadata_props {
key: "scope"
value: "linear"
}
metadata_props {
key: "module[0]"
value: "sphinx_runpython.runpython.sphinx_runpython_extension.run_python_script_136119210221504.<locals>.Neuron"
}
metadata_props {
key: "module[1]"
value: "torch.nn.modules.linear.Linear"
}
metadata_props {
key: "source[0]"
value: "linear.forward"
}
metadata_props {
key: "intypes"
value: "FLOAT / FLOAT / FLOAT"
}
metadata_props {
key: "outtypes"
value: "FLOAT"
}
metadata_props {
key: "inshapes"
value: "(2, 5) / (3, 5) / (3,)"
}
metadata_props {
key: "outshapes"
value: "(2, 3)"
}
, input: "linear"
output: "output_0"
name: "relu"
op_type: "Relu"
doc_string: "#Io1.\n-T1:2x3\nrelu:1:(2,3)"
domain: ""
metadata_props {
key: "module[0]"
value: "sphinx_runpython.runpython.sphinx_runpython_extension.run_python_script_136119210221504.<locals>.Neuron"
}
metadata_props {
key: "intypes"
value: "FLOAT"
}
metadata_props {
key: "outtypes"
value: "FLOAT"
}
metadata_props {
key: "inshapes"
value: "(2, 3)"
}
metadata_props {
key: "outshapes"
value: "(2, 3)"
}
]
ExportOptions#
ExportOptions encapsulates the strategy
used to obtain the FX graph from the module. Several named strategies are
available:
| Strategy string | Meaning |
|---|---|
| `"strict"` | `torch.export.export()` with `strict=True` |
| `"nostrict"` | `torch.export.export()` with `strict=False` |
| `"dynamo"` | export through TorchDynamo |
| `"tracing"` | symbolic tracing via `torch.fx` |
| `"jit"` | JIT script → FX graph |
| `"fake"` | export with `FakeTensor` inputs (see below) |
| `-dec` suffix (e.g. `"nostrict-dec"`) | default decomposition table, default strict setting |
| full-decomposition variants | full decomposition table, default strict setting |
Fake tensors#
When fake=True (or strategy="fake"), the export stage replaces every
real input tensor with a FakeTensor
— a lightweight stand-in that carries dtype, shape, and device metadata but
holds no actual data. This is useful when loading the model's weights into
memory just to trace the graph would be prohibitively expensive (e.g. for
very large language models).
The conversion from real tensors to fake tensors is handled by
make_fake_with_dynamic_dimensions(), which
also ensures that dimensions sharing the same name in dynamic_shapes are
mapped to the same symbolic integer. The
FakeTensorContext manages the
underlying FakeTensorMode and the
mapping between concrete dimension values and their symbolic counterparts.
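The name-sharing behaviour can be illustrated with a small stdlib-only sketch. The `DimInterner` class below is a hypothetical stand-in, not the library's actual implementation; it only shows the interning idea: two dimensions annotated with the same name end up with the same symbolic integer.

```python
import itertools


class DimInterner:
    """Illustrative stand-in for the dynamic-dimension mapping:
    dimensions sharing the same name map to the same symbol."""

    def __init__(self):
        self._counter = itertools.count()
        self._by_name = {}

    def symbol_for(self, name):
        # Reuse the existing symbol when the name was seen before.
        if name not in self._by_name:
            self._by_name[name] = f"s{next(self._counter)}"
        return self._by_name[name]


interner = DimInterner()
# "batch" appears in two different inputs: both occurrences share one symbol.
a = interner.symbol_for("batch")
b = interner.symbol_for("seq")
c = interner.symbol_for("batch")
print(a, b, c)  # s0 s1 s0
```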
DynamoInterpreter#
DynamoInterpreter is the heart
of the converter. It walks the torch.fx.Graph node by node and
translates each node into one or more ONNX operators appended to the
GraphBuilder.
Node kinds and their handlers#
| FX node kind | Action taken by the interpreter |
|---|---|
| `placeholder` | Registers a graph input with the correct ONNX type and shape. |
| `get_attr` | Looks up a weight / buffer / constant via the retriever callable and registers it as an ONNX initializer. |
| `call_function` | Looks up the ATen function in the aten-function registry (`_aten_functions` and related modules) and invokes the converter. |
| `call_method` | Looks up the ATen method in the method registry (`_aten_methods`). |
| `call_module` | Recursively converts a submodule using a nested interpreter. |
| `output` | Registers graph outputs, applying any output masks produced by the export stage. |
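The node-by-node dispatch can be sketched in plain Python. The node representation and handler names below are simplified stand-ins (not the actual FX or interpreter classes); the sketch only mirrors the routing shown in the table above.

```python
# Minimal sketch of an FX-style node-by-node dispatch loop:
# each node kind routes to a dedicated handler.

def handle_placeholder(node, graph):
    graph["inputs"].append(node["name"])

def handle_get_attr(node, graph):
    graph["initializers"].append(node["name"])

def handle_call_function(node, graph):
    graph["nodes"].append(("call", node["target"]))

def handle_output(node, graph):
    graph["outputs"].extend(node["args"])

HANDLERS = {
    "placeholder": handle_placeholder,
    "get_attr": handle_get_attr,
    "call_function": handle_call_function,
    "output": handle_output,
}

def run(nodes):
    graph = {"inputs": [], "initializers": [], "nodes": [], "outputs": []}
    for node in nodes:
        HANDLERS[node["op"]](node, graph)
    return graph

trace = [
    {"op": "placeholder", "name": "x"},
    {"op": "get_attr", "name": "linear.weight"},
    {"op": "call_function", "target": "aten.relu.default"},
    {"op": "output", "args": ["output_0"]},
]
g = run(trace)
print(g["inputs"], g["outputs"])
```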
ATen-to-ONNX converters#
Each call_function node carries a torch._ops.OpOverload such as
aten.relu.default or aten.mm.default. The interpreter resolves this
to a Python function whose name follows the pattern aten_<op>_<overload>
(e.g. aten_relu_default) defined in one of:
- `yobx.torch.interpreter._aten_functions` — the bulk of ATen ops.
- `yobx.torch.interpreter._aten_methods` — tensor methods (e.g. `.view`, `.contiguous`).
- `yobx.torch.interpreter._aten_functions_attention` — attention-related ops (`scaled_dot_product_attention` etc.).
- `yobx.torch.interpreter._prims_functions` — `torch.ops.prims.*` primitives.
- `yobx.torch.interpreter._math_functions` — `torch.ops.higher_order` and math helpers.
- `yobx.torch.interpreter._non_aten_functions` — custom and non-ATen ops.
Each converter function has the signature:
def aten_<op>_<overload>(
    g: GraphBuilder,
    sts: Dict[str, Any],   # shape/type state
    outputs: List[str],    # desired output names
    *args,
    **kwargs,
) -> str:
    ...
It appends one or more ONNX nodes to g and returns the name of the
primary output tensor.
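The name-mangling step can be sketched as follows; the registry dict here is a toy stand-in for the `_aten_functions` family of modules, and `converter_name` is a hypothetical helper illustrating the `aten_<op>_<overload>` pattern.

```python
def converter_name(op_overload: str) -> str:
    """Map an OpOverload string such as 'aten.relu.default'
    to the converter naming pattern aten_<op>_<overload>."""
    namespace, op, overload = op_overload.split(".")
    return f"{namespace}_{op}_{overload}"


# A toy registry standing in for _aten_functions and friends.
def aten_relu_default(g, sts, outputs, x):
    # A real converter would append a Relu node to g; here we
    # just return the primary output name, as converters do.
    return outputs[0]


REGISTRY = {"aten_relu_default": aten_relu_default}

name = converter_name("aten.relu.default")
fn = REGISTRY[name]
print(name)  # aten_relu_default
```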
Dispatcher#
Dispatcher allows callers to override or
extend the built-in ATen converter mapping without modifying the library.
It is especially useful when a model uses custom ops or when the default
conversion of a particular op should be replaced.
from yobx.torch.interpreter import Dispatcher, to_onnx
from yobx.xbuilder import GraphBuilder


def my_relu(g: GraphBuilder, sts, outputs, x):
    return g.op.Relu(x, outputs=outputs)


dispatcher = Dispatcher({"aten_relu_default": my_relu})
onx = to_onnx(model, (x,), dispatcher=dispatcher)
ForceDispatcher is a stricter variant that
raises an error when the requested function is not found, making it easier to
discover missing converters during development.
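The difference between the two lookup policies can be sketched in plain Python. The classes below are simplified reimplementations for illustration only, not the library's actual `Dispatcher` / `ForceDispatcher` code.

```python
class SketchDispatcher:
    """Lenient lookup: overrides first, then the built-in
    registry, None when nothing matches."""

    def __init__(self, overrides, registry=None):
        self.overrides = overrides
        self.registry = registry or {}

    def find(self, name):
        return self.overrides.get(name) or self.registry.get(name)


class SketchForceDispatcher(SketchDispatcher):
    """Strict variant: a missing converter is an error, which
    surfaces gaps in coverage early during development."""

    def find(self, name):
        fn = super().find(name)
        if fn is None:
            raise RuntimeError(f"no converter registered for {name!r}")
        return fn


builtin = {"aten_relu_default": lambda *a: "Relu"}

d = SketchDispatcher({}, builtin)
print(d.find("aten_missing_default"))  # None

f = SketchForceDispatcher({}, builtin)
try:
    f.find("aten_missing_default")
except RuntimeError as exc:
    print(exc)
```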
GraphBuilder#
The GraphBuilder accumulates ONNX ops and is
responsible for:
- Type and shape propagation — every result tensor is given an ONNX type and, when possible, a concrete or symbolic shape.
- Optimisation — constant folding, cast elimination, identity removal, and other peephole passes controlled by `OptimizationOptions`.
- Serialisation — once all nodes have been appended, `to_onnx()` serialises the accumulated state into an `onnx.ModelProto`.
Large models#
When large_model=True the converter returns an
onnx.model_container.ModelContainer instead of an
onnx.ModelProto. This defers the decision on whether to embed the
weights inside the protobuf or to store them as external data files.
The external_threshold parameter (default: 1024 bytes) controls which
initializers are treated as external when the container is later saved.
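The threshold rule amounts to a size-based partition of the initializers. The helper below is a hypothetical stdlib-only sketch; only the `external_threshold` default comes from the API, and whether the real code compares with `>` or `>=` is an assumption here.

```python
def split_initializers(initializers, external_threshold=1024):
    """Partition initializers by serialized size in bytes: anything
    at or above the threshold goes to external data (assumption:
    the exact comparison operator is illustrative)."""
    embedded, external = {}, {}
    for name, data in initializers.items():
        target = external if len(data) >= external_threshold else embedded
        target[name] = data
    return embedded, external


inits = {
    "linear.bias": bytes(12),      # 3 float32 values -> stays in the protobuf
    "linear.weight": bytes(4096),  # large tensor -> external data file
}
embedded, external = split_initializers(inits)
print(sorted(embedded), sorted(external))
```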
Dynamic shapes#
The dynamic_shapes argument is forwarded to torch.export.export()
and follows the same nested-dict / nested-tuple convention. Symbolic
dimension variables should be instances of torch.export.Dim.
use_dyn_not_str() is a convenience helper that replaces
string-valued dimension annotations (which some helpers return) with
torch.export.Dim.DYNAMIC:
from yobx.torch import use_dyn_not_str
dynamic_shapes = use_dyn_not_str({"x": {0: "batch", 1: "seq"}})
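The substitution can be approximated in a few lines of stdlib Python. `DYNAMIC` below is a local sentinel standing in for `torch.export.Dim.DYNAMIC`, and `dyn_not_str` is a simplified sketch of what `use_dyn_not_str()` does, not its actual implementation.

```python
DYNAMIC = object()  # stand-in for torch.export.Dim.DYNAMIC


def dyn_not_str(spec):
    """Recursively replace string-valued dimension annotations
    with the DYNAMIC sentinel, leaving other values untouched."""
    if isinstance(spec, str):
        return DYNAMIC
    if isinstance(spec, dict):
        return {k: dyn_not_str(v) for k, v in spec.items()}
    if isinstance(spec, (list, tuple)):
        return type(spec)(dyn_not_str(v) for v in spec)
    return spec


spec = {"x": {0: "batch", 1: "seq"}}
out = dyn_not_str(spec)
print(out["x"][0] is DYNAMIC)  # True
```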
For more complex scenarios — especially LLMs with prefill/decode phases —
InputObserver can infer the dynamic_shapes
automatically from real forward passes.
Submodules as local ONNX functions#
When export_modules_as_functions=True (or a set of module types), the
converter unfolds the model via torch.export.unflatten() so that each
submodule becomes a separate ONNX local function. This preserves the
module hierarchy in the ONNX graph and can reduce model size when the same
submodule type is reused many times.
The granularity is controlled by the function_options parameter (a
FunctionOptions instance) which specifies, among
other things, how initializers inside local functions should be represented
(as constants inlined into the function body or as additional function
inputs).
Validation#
Setting validate_onnx=True (or a float tolerance) causes
to_onnx() to run the exported ONNX model with the same inputs that
were used to export it and compare the outputs against the PyTorch model’s
outputs. AssertionError is raised if the maximum absolute
difference exceeds the tolerance (default 1e-5).
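The comparison boils down to a maximum-absolute-difference check against the tolerance. The function below is a minimal stdlib sketch of that logic (the name `assert_close` and the flat-sequence interface are illustrative, not the library's API).

```python
def assert_close(expected, got, atol=1e-5):
    """Compare two flat sequences of floats and raise AssertionError
    when the maximum absolute difference exceeds the tolerance."""
    diff = max(abs(e - g) for e, g in zip(expected, got))
    if diff > atol:
        raise AssertionError(f"max abs diff {diff} > {atol}")
    return diff


# Within tolerance: returns the observed difference.
print(assert_close([1.0, 2.0], [1.0, 2.0 + 1e-6]))

# Out of tolerance: raises, as to_onnx() does when validation fails.
try:
    assert_close([1.0], [1.5])
except AssertionError as exc:
    print(exc)
```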
Environment variables#
Several environment variables alter the converter's behaviour without requiring code changes. Among other effects, they can:

- increase verbosity inside the conversion pipeline,
- print the FX graph before interpretation,
- show a progress bar for large models,
- print additional intermediate state.
Debugging when Exporting with GraphBuilder#
GraphBuilder reads several environment variables at
construction time that raise an exception as soon as a named result is
assigned a shape, type, or value. Setting one of these is the fastest way to
get a Python traceback pointing at the exact line that produces a suspicious
tensor.
Each variable names a result (or a value) to watch. Depending on the variable, the builder will:

- raise when the named result is created;
- raise when the named result is assigned a shape;
- raise when the named result is assigned a type;
- raise when the named result is assigned a value;
- enable extra logging in the shape-value computation path for the named result;
- raise when a node whose output contains the named string is added to the graph;
- raise when the named dynamic dimension appears;
- log every constant that is evaluated during shape inference;
- raise when a shape cannot be inferred (instead of silently leaving it unknown).
See also#
Flattening Functionalities (torch) — registering custom pytree nodes before export.
Patches (torch export) — patching torch / transformers internals for successful symbolic tracing.
InputObserver — automatic inference of export arguments and dynamic shapes.
Debugging when Exporting with GraphBuilder — environment variables for tracing converter issues (this page).
Debugging — GraphBuilder debugging environment variables.
yobx.xbuilder.GraphBuilder — the underlying ONNX graph builder.
yobx.torch.export_options.ExportOptions — all export strategy options.
yobx.torch.interpreter.to_onnx() — the public conversion API.