Patches (torch export)#
Note
This section covers functionality that is specific to PyTorch.
It is only relevant when exporting torch.nn.Module models with
torch.export.export() and has no bearing on ONNX models built
directly with the builder APIs.
Before exporting a torch.nn.Module with torch.export.export(),
certain internal functions in torch and transformers must be
temporarily replaced with corrected versions to avoid crashes or incorrect
behaviour during tracing. The patching infrastructure in
yobx.helpers.patch_helper and yobx.torch.patch provides a
structured, reversible way to apply, track and report such replacements.
Why patches are needed#
torch.export.export() symbolically traces a model, meaning that every
operation is recorded with symbolic dimension variables rather than concrete
sizes. Several low-level torch and transformers helpers were not written to
cope with symbolic dimensions and raise exceptions or silently return wrong
results when they encounter them.
Control flow is the primary source of failures#
The most fundamental issue is Python control flow inside traced functions.
During symbolic tracing, dimension values are torch.SymInt objects
rather than plain integers. Any code written as:
if condition_on_size:
    raise SomeError(...)
will crash the exporter: evaluating the if forces the symbolic
condition, a SymBool, to a concrete boolean, which cannot be
resolved at trace time and raises GuardOnDataDependentSymNode.
The fix is to replace such guards with torch._check():
torch._check(condition_on_size)
torch._check() is understood by the exporter and recorded as a
constraint on the symbolic value instead of executing a branch. The patches
shipped with this library systematically replace if ... : raise patterns
in torch and transformers internals with torch._check(...) equivalents so
that symbolic tracing can proceed without crashing.
Other common failure modes#
GuardOnDataDependentSymNode — a symbolic size cannot be evaluated to a concrete boolean at trace time, crashing broadcasting helpers such as infer_size or _broadcast_shapes.
Incorrect range-constraint computation — _get_range_constraints uses the wrong argument order, causing dynamic shape constraints to be mis-assigned.
Wrapped RotaryEmbedding.forward — some transformers model variants wrap RotaryEmbedding.forward with a decorator that introduces control flow invisible to the exporter; the wrapper must be replaced with a traceable version.
Rather than forking torch or transformers, patches swap in corrected
implementations only for the duration of the export, then restore the originals.
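The apply-then-restore mechanism can be sketched in isolation. This generic context manager is a simplified stand-in for illustration, not the library's actual implementation:

```python
import contextlib
import math


@contextlib.contextmanager
def temporary_patch(target, name, replacement):
    """Replace target.name with replacement, restoring the original on exit."""
    original = getattr(target, name)
    setattr(target, name, replacement)
    try:
        yield original
    finally:
        setattr(target, name, original)


with temporary_patch(math, "sqrt", lambda v: -1.0):
    assert math.sqrt(4.0) == -1.0  # the patched version is active
assert math.sqrt(4.0) == 2.0       # the original is restored afterwards
```

The try/finally guarantees the original is put back even if the export raises inside the with block.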
yobx.torch.tiny_models.TinyBroadcastAddModel illustrates one such
problem. The exporter makes a strong assumption when inferring shapes
after a broadcast. Two dimensions are broadcastable if they are equal
or if one of them is 1. With symbolic dimensions, this condition
sometimes cannot be verified at trace time because neither value is
known. In that case it is reasonable to assume the model is correct
and that the output dimension is the maximum of the two broadcast
dimensions. The patches make this assumption possible, as
yobx.torch.tiny_models.TinyBroadcastAddModel demonstrates.
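The broadcasting rule in question is easy to see with concrete shapes; when it applies, each output dimension is the maximum of the two input dimensions:

```python
import torch

x = torch.ones(4, 1)
y = torch.ones(1, 5)
z = x + y
# Dimensions of size 1 expand to match the other operand, so each
# output dimension is max(x_dim, y_dim).
print(z.shape)  # torch.Size([4, 5])
```

With symbolic dimensions the exporter cannot always prove that one side is 1 or that both sides are equal, which is exactly the situation the patched helpers handle.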
Core data structures#
PatchInfo#
PatchInfo describes a single patch:
patch — the replacement callable.
do / undo — lambdas that swap the function in and out of the target module or class.
family — a free-form category string ("torch" or "transformers") used for reporting.
depends_on — a list of PatchInfo objects that must also be applied for this patch to work correctly.
The convenience constructor make()
covers the common pattern of replacing an attribute on a module or class:
<<<

import torch
import torch._refs

from yobx.helpers.patch_helper import PatchInfo


def my_patched_fn(*shapes):
    """Minimal stand-in that just returns the first non-empty shape."""
    for s in shapes:
        if s:
            return list(s)
    return []


patch = PatchInfo.make(
    my_patched_fn,
    torch._refs,
    "_broadcast_shapes",
    family="torch",
)

# Apply the patch.
patch.do()
print("active patch:", patch.name)
print("patched function:", torch._refs._broadcast_shapes.__name__)

# Restore the original.
patch.undo()
print("after undo:", torch._refs._broadcast_shapes.__name__)

>>>

active patch: my_patched_fn
patched function: my_patched_fn
after undo: _broadcast_shapes
make_diff() produces a unified diff
between the original function and the patched replacement. Calling this for
every active patch (or via
make_report()) is important
because it gives an exhaustive picture of what changed — including patches
that are only reachable through a module method (e.g. a RotaryEmbedding
variant whose forward is called indirectly from the top-level
nn.Module). Without the full diff it is easy to overlook a patched
sub-function and miss the reason a particular graph node was introduced:
import torch
import torch._refs

from yobx.helpers.patch_helper import PatchInfo


def my_patched_fn(*shapes):
    """Returns the first non-empty shape."""
    for s in shapes:
        if s:
            return list(s)
    return []


patch = PatchInfo.make(
    my_patched_fn,
    torch._refs,
    "_broadcast_shapes",
    family="torch",
)

patch.do()
diff = patch.make_diff()
patch.undo()

# Print the first few lines of the diff.
print("\n".join(diff.splitlines()[:10]))
That gives:
--- original
+++ rewritten
@@ -1,77 +1,6 @@
-def _broadcast_shapes(*_shapes):
- from torch.fx.experimental.symbolic_shapes import (
- guard_or_false,
- has_hint,
- is_nested_int,
- size_hint,
- )
...
Or with diff = patch.format_diff("rst"):
.. _patch-torch-my_patched_fn:
torch: _broadcast_shapes -> my_patched_fn
=========================================
.. code-block:: diff
:linenos:
--- original
+++ rewritten
@@ -1,77 +1,6 @@
-def _broadcast_shapes(*_shapes):
- from torch.fx.experimental.symbolic_shapes import (
- guard_or_false,
- has_hint,
- is_nested_int,
- size_hint,
- )
-
- backed_so = torch.fx.experimental._config.backed_size_oblivious
...
PatchDetails#
PatchDetails is an ordered collection of
PatchInfo objects. It is the object
yielded by the apply_patches_for_model() context
manager so that callers can inspect every patch that was active during export.
Useful methods:
find() — look up a patch by function name.
patches_involved_in_graph() — identify which patches contributed nodes to a graph by cross-referencing node.meta["stack_trace"] with each patch’s source location. The method is designed for torch.fx.Graph but accepts any object that provides:
graph.nodes — an iterable of node objects;
node.meta — a dict on each node;
node.meta["stack_trace"] — the call-stack string captured when the node was created.
Any custom graph representation that exposes these three attributes works equally well.
make_report() — render a human-readable report (raw text or RST) listing each involved patch together with its diff and the affected graph nodes.
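A minimal stand-in satisfying that duck-typed interface can be built without torch.fx at all; the classes below are hypothetical and exist only to show the three required attributes:

```python
class FakeNode:
    """Node exposing only the `meta` dict the method reads."""

    def __init__(self, stack_trace):
        self.meta = {"stack_trace": stack_trace}


class FakeGraph:
    """Graph exposing only an iterable of nodes."""

    def __init__(self, nodes):
        self.nodes = nodes


graph = FakeGraph(
    [FakeNode('File "my_patched_fn.py", line 3, in my_patched_fn')]
)
# An object shaped like this could be handed to
# PatchDetails.patches_involved_in_graph() in place of a torch.fx.Graph.
assert all("stack_trace" in node.meta for node in graph.nodes)
```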
Applying patches#
The recommended entry point is the context manager
apply_patches_for_model():
from yobx.torch import apply_patches_for_model

with apply_patches_for_model(
    patch_torch=True,
    patch_transformers=True,
    model=model,
    verbose=1,
) as details:
    ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)

# After the `with` block every patch has been restored.
for patch in details:
    print(patch.format_diff(format="raw"))
Parameters:
patch_torch — applies the four patches in yobx.torch.in_torch.patches that fix symbolic-shape handling inside torch.
patch_transformers — applies the RotaryEmbedding patches from yobx.torch.in_transformers.patches. When model is provided, the function inspects each sub-module for wrapped RotaryEmbedding.forward implementations and adds a model-specific patch automatically.
model — the torch.nn.Module being exported. Supplying it enables automatic detection of non-standard RotaryEmbedding variants.
verbose — print each patch name as it is applied or removed.
Shipped patches#
torch patches#
The following patches are registered in yobx.torch.in_torch.patches
(family "torch"):
| Patch function | What it fixes |
|---|---|
| _broadcast_shapes | GuardOnDataDependentSymNode crashes when broadcasting symbolic sizes |
| infer_size | the same guard failure in shape-inference helpers |
| _get_range_constraints | wrong argument order that mis-assigns dynamic shape constraints |
transformers patches#
get_patches_for() returns patches for
the transformers library. When a model is supplied the function walks its
sub-modules and adds a model-specific patch for every RotaryEmbedding whose
forward method has been wrapped (common in recent transformers releases
where dynamic_rope_update is applied via a decorator).
The RotaryEmbedding patch replaces the wrapped forward with
common_RotaryEmbedding.forward from
yobx.torch.in_transformers._patches_model_rope_utils, and adds a
dependency patch that replaces
transformers.modeling_rope_utils.dynamic_rope_update with a traceable
version.
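Decorator-wrapped forward methods can be detected generically because functools.wraps-style decorators set __wrapped__ on the wrapper. The sketch below is illustrative only; the decorator, module classes, and helper name are all invented here, not taken from the library:

```python
import functools

import torch


def rope_decorator(fn):
    # Stands in for a decorator such as dynamic_rope_update.
    @functools.wraps(fn)
    def wrapper(self, x):
        return fn(self, x)

    return wrapper


class FakeRotaryEmbedding(torch.nn.Module):
    @rope_decorator
    def forward(self, x):
        return x


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rope = FakeRotaryEmbedding()

    def forward(self, x):
        return self.rope(x)


def find_wrapped_forwards(model):
    """Hypothetical helper: names of sub-modules whose forward is wrapped."""
    return [
        name
        for name, module in model.named_modules()
        if hasattr(type(module).forward, "__wrapped__")
    ]


print(find_wrapped_forwards(Model()))  # ['rope']
```

Sub-modules flagged this way are the ones whose forward would need a traceable replacement before export.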
List of Patches#
See Patches List.
See also
Flattening Functionalities (torch) — registering pytree nodes for DynamicCache
and other transformers classes, which is typically used alongside patches.
yobx.helpers.patch_helper — PatchInfo and
PatchDetails API reference.
yobx.torch.patch — apply_patches_for_model() API reference.
yobx.torch.in_torch.patches — torch patch implementations.
yobx.torch.in_transformers.patches — transformers patch
implementations.