Patches Explained¶
Function onnx_diagnostic.torch_export_patches.torch_export_patches()
implements four kinds of patches to make it easier to export a model, usually
coming from transformers.
All patches are implemented in onnx_diagnostic.torch_export_patches.
Four Kinds of Patches¶
with torch_export_patches(...) as f:
    ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
torch fixes: it disables some exceptions or improves some functions related to dynamic shapes until torch addresses the issues (see mostly exporter issues)
transformers rewriting: some methods are replaced with a version torch.export.export() can understand; some rewritings may migrate to transformers, others are applied only at export time because they would make the implementation less efficient
cache serialization: torch.export.export() needs to know how to serialize custom classes such as transformers.cache_utils.DynamicCache
control flow rewriting: control flow (if, for) cannot be exported as is; there is still some work to be done to process it automatically. This package offers some automated rewriting, but it is far from perfect.
All of them are triggered by onnx_diagnostic.torch_export_patches.torch_export_patches().
python -m onnx_diagnostic validate \
-m hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration \
--run -v 1 --export onnx-dynamo -o dump_test --dtype float16 --device cuda
All patches can be disabled by passing patch=False: with torch_export_patches(patch=False).
torch fixes¶
Implemented in onnx_diagnostic.torch_export_patches.patches.patch_torch
and triggered with the context
torch_export_patches(patch_sympy=True, patch_torch=True, catch_constraints=True, stop_if_static=1...).
These patches fix some issues found while exporting models. Some of them might not be needed anymore.
They improve shape broadcasting or raise an exception every time a dynamic dimension
becomes static (stop_if_static=1).
transformers rewriting¶
Implemented in onnx_diagnostic.torch_export_patches.patches.patch_transformers
and triggered with the context torch_export_patches(patch_transformers=True).
Every class implementing a patch is prefixed with patched_. It contains two class attributes:
_PATCHES_ contains the list of methods to replace.
_PATCHED_CLASS_ is the class patched by this one.
class patched_AttentionMaskConverter:
    """
    Patches
    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
    """
    # This method was fixed in 4.51 at least.
    _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
    _PATCHED_CLASS_ = AttentionMaskConverter
The package automatically parses this file to extract the patched methods.
More can be added by populating the argument custom_patches:
torch_export_patches(patch_transformers=True, custom_patches=[...]).
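The convention described above (a patched_ class exposing _PATCHES_ and _PATCHED_CLASS_) can be illustrated with a minimal, self-contained sketch. It is not the implementation used by onnx_diagnostic: the classes Greeter and patched_Greeter and the helper apply_patches are hypothetical names introduced only for this illustration. The sketch shows one possible way to replace the listed methods and to restore the originals when the context exits.

```python
import contextlib


class Greeter:
    """Hypothetical class standing in for a transformers class."""

    def greet(self, name):
        return f"hello {name}"


class patched_Greeter:
    """Hypothetical patch following the naming convention described above."""

    _PATCHES_ = ["greet"]  # methods to replace
    _PATCHED_CLASS_ = Greeter  # class receiving the replacements

    def greet(self, name):
        return f"patched hello {name}"


@contextlib.contextmanager
def apply_patches(*patch_classes):
    """Replaces the listed methods, then restores the originals on exit."""
    saved = []
    for patch in patch_classes:
        target = patch._PATCHED_CLASS_
        for name in patch._PATCHES_:
            # Remember the original method so it can be restored later.
            saved.append((target, name, getattr(target, name)))
            setattr(target, name, getattr(patch, name))
    try:
        yield
    finally:
        for target, name, original in reversed(saved):
            setattr(target, name, original)


with apply_patches(patched_Greeter):
    inside = Greeter().greet("export")
outside = Greeter().greet("export")
print(inside)
print(outside)
```

Restoring the methods in a finally clause keeps the original class intact even if the export fails inside the context.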
Here is the list of available patches:
<<<
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p
for name, cls in p.__dict__.items():
    if name.startswith("patched_") and hasattr(cls, "_PATCHES_"):
        print(
            f"{cls._PATCHED_CLASS_.__name__}: "
            f"{', '.join([_ for _ in cls._PATCHES_ if _ is not None])}"
        )
>>>
DynamicLayer: lazy_initialization
AttentionMaskConverter:
GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting
GemmaRotaryEmbedding: forward
Gemma2RotaryEmbedding: forward
Gemma3RotaryEmbedding: forward
LlamaRotaryEmbedding: forward
MistralRotaryEmbedding: forward
MixtralRotaryEmbedding: forward
PhiRotaryEmbedding: forward
Phi3RotaryEmbedding: forward
Phi4MultimodalRotaryEmbedding: forward
SmolLM3RotaryEmbedding: forward
IdeficsEmbedding: forward
IdeficsAttention: forward
SamMaskDecoder: forward
VisionAttention: forward
Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
Gemma3Model: get_placeholder_mask
Cache serialization¶
Implemented in onnx_diagnostic.torch_export_patches.onnx_export_serialization.
Any custom class manipulated by a model needs to be registered through
torch.utils._pytree.register_pytree_node
or with
onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization().
The registration is triggered by the context torch_export_patches(patch_transformers=True).
register_class_serialization() registers one class,
onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization()
registers all known classes.
The registration can be undone with
onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization()
or onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization().
Here is the list of supported caches:
<<<
import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
print(
    "\n".join(
        sorted(
            t.__name__
            for t in p.serialization_functions(
                patch_transformers=True, patch_diffusers=True
            )
        )
    )
)
>>>
BaseModelOutput
DynamicCache
EncoderDecoderCache
HybridCache
MambaCache
SlidingWindowCache
StaticCache
UNet2DConditionOutput
Control flow rewriting¶
This is an attempt to automatically rewrite control flow using ast.
It is implemented in onnx_diagnostic.torch_export_patches.patch_module and
triggered with torch_export_patches(rewrite=<instance of torch.nn.Module>).
Option dump_rewriting=<folder> tells the function to dump all applied
rewritings.
The following example contains the rewriting of method
transformers.models.bart.modeling_bart.BartEncoderLayer.forward().
The list of known rewritings to apply is returned by function
onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting()
and they are applied by function onnx_diagnostic.torch_export_patches.patch_module.transform_method().
The parsed code carries no type information, even though
torch.export.export() knows the types. Because of that, the automation usually needs manual tuning:
filtering out some tests (argument filter_node) or pre/post processing
(arguments pre_rewriter, post_rewriter) of function
onnx_diagnostic.torch_export_patches.patch_module.transform_method().
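The idea of filtering tests can be sketched with the standard module ast alone. The example below is an assumption-laden illustration, not the code of transform_method(): the function select_tests, the predicate mentions_float16 and the simplified forward source are hypothetical, and the actual signature and meaning of filter_node in onnx_diagnostic may differ. It only shows how a predicate on ast.If nodes can separate the tests worth rewriting from the others without any type information.

```python
import ast
import textwrap

# Simplified, hypothetical source with two tests, as in the discussion above.
SOURCE = textwrap.dedent(
    """
    def forward(self, hidden_states, output_attentions):
        if hidden_states.dtype == torch.float16:
            hidden_states = clamp(hidden_states)
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
    """
)


def select_tests(source, filter_node=None):
    """Returns the source of every ``if`` test accepted by ``filter_node``."""
    tree = ast.parse(source)
    return [
        ast.unparse(node.test)
        for node in ast.walk(tree)
        if isinstance(node, ast.If) and (filter_node is None or filter_node(node))
    ]


def mentions_float16(node):
    """Hypothetical filter: keeps only tests whose source mentions float16."""
    return "float16" in ast.unparse(node.test)


all_tests = select_tests(SOURCE)
kept_tests = select_tests(SOURCE, mentions_float16)
print(all_tests)
print(kept_tests)
```

A textual filter like this one is crude, but it reflects the constraint stated above: the syntax tree alone does not say whether a test involves a tensor or a plain Python value.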
The main entry point is the context
onnx_diagnostic.torch_export_patches.torch_export_rewrite()
which rewrites and undoes the rewriting.
For example, the model transformers.BartForConditionalGeneration
requires the following value for parameter rewrite:
<<<
import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)
pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))
>>>
[{'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x77e01247ad40>,
'function': <function BartEncoderLayer.forward at 0x77e0408e2980>,
'pre_rewriter': <function ast_or_into_bitor at 0x77e040874860>},
{'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x77e01247ad40>,
'function': <function PLBartEncoderLayer.forward at 0x77e0321bf380>,
'pre_rewriter': <function ast_or_into_bitor at 0x77e040874860>}]
This method has two tests. Only the first one needs to be rewritten.
The second one manipulates a tuple and the automated rewriting does not handle
that because it cannot detect types. That explains why the parameter
filter_node is filled. Then, the first test includes a condition relying on or,
which must be replaced by |. That explains the parameter pre_rewriter.
We finally get:
--- original
+++ rewritten
@@ -26,7 +26,6 @@
     hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
     hidden_states = residual + hidden_states
     hidden_states = self.self_attn_layer_norm(hidden_states)
-
     residual = hidden_states
     hidden_states = self.activation_fn(self.fc1(hidden_states))
     hidden_states = nn.functional.dropout(
@@ -37,15 +36,22 @@
     hidden_states = residual + hidden_states
     hidden_states = self.final_layer_norm(hidden_states)
-    if hidden_states.dtype == torch.float16 and (
-        torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
-    ):
+    def branch_cond_then_1(hidden_states):
         clamp_value = torch.finfo(hidden_states.dtype).max - 1000
         hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+        return hidden_states.clone()
+    def branch_cond_else_1(hidden_states):
+        return hidden_states.clone()
+
+    hidden_states = torch.cond(
+        hidden_states.dtype == torch.float16
+        and torch.isinf(hidden_states).any() | torch.isnan(hidden_states).any(),
+        branch_cond_then_1,
+        branch_cond_else_1,
+        [hidden_states],
+    )
     outputs = (hidden_states,)
-
     if output_attentions:
-        outputs += (attn_weights,)
-
+        outputs = outputs + (attn_weights,)
     return outputs
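The pre-processing step that replaces or by | can be sketched with an ast.NodeTransformer. This is only an illustration of the technique: the class OrIntoBitOr and the function or_into_bitor are hypothetical names, and the function ast_or_into_bitor shipped with the package may be implemented differently.

```python
import ast


class OrIntoBitOr(ast.NodeTransformer):
    """Hypothetical transformer rewriting ``a or b`` into ``a | b``."""

    def visit_BoolOp(self, node):
        self.generic_visit(node)  # rewrite nested expressions first
        if not isinstance(node.op, ast.Or):
            return node  # ``and`` is left untouched
        # Fold the operands left to right: a or b or c -> (a | b) | c
        result = node.values[0]
        for value in node.values[1:]:
            result = ast.BinOp(left=result, op=ast.BitOr(), right=value)
        return ast.copy_location(result, node)


def or_into_bitor(source: str) -> str:
    """Parses ``source``, applies the transformer and returns the new code."""
    tree = OrIntoBitOr().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


rewritten = or_into_bitor("flag = torch.isinf(x).any() or torch.isnan(x).any()")
print(rewritten)
```

The operator | is evaluated on tensors by torch, whereas or forces Python to convert the left operand into a boolean, which is what prevents the test from being traced.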
This float16 clamp rewriting has to be applied in the following locations:
<<<
import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    known_transformers_rewritings_clamp_float16,
)
pprint.pprint(known_transformers_rewritings_clamp_float16())
>>>
{'AutoformerEncoder': 'AutoformerEncoderLayer',
'AutoformerEncoderLayer': 'AutoformerEncoderLayer',
'AutoformerForPrediction': 'AutoformerEncoderLayer',
'AutoformerModel': 'AutoformerEncoderLayer',
'BartEncoderLayer': 'BartEncoderLayer',
'BartForConditionalGeneration': 'BartEncoderLayer',
'BartModel': 'BartEncoderLayer',
'BigBirdPegasusForCausalLM': 'BigBirdPegasusEncoderLayer',
'BigBirdPegasusForConditionalGeneration': 'BigBirdPegasusEncoderLayer',
'BigBirdPegasusForQuestionAnswering': 'BigBirdPegasusEncoderLayer',
'BlenderbotSmallEncoderLayer': 'BlenderbotSmallEncoderLayer',
'BlenderbotSmallForCausalLM': 'BlenderbotSmallEncoderLayer',
'BlenderbotSmallForConditionalGeneration': 'BlenderbotSmallEncoderLayer',
'InformerEncoderLayer': 'InformerEncoderLayer',
'InformerForPrediction': 'InformerEncoderLayer',
'LEDClassificationHead': 'LEDEncoderLayer',
'LEDEncoderLayer': 'LEDEncoderLayer',
'LEDForConditionalGeneration': 'LEDEncoderLayer',
'MarianEncoder': 'MarianEncoderLayer',
'MarianEncoderLayer': 'MarianEncoderLayer',
'MarianMTModel': 'MarianEncoderLayer',
'MarianModel': 'MarianEncoderLayer',
'MvpEncoderLayer': 'MvpEncoderLayer',
'MvpForCausalLM': 'MvpEncoderLayer',
'MvpForConditionalGeneration': 'MvpEncoderLayer',
'MvpForQuestionAnswering': 'MvpEncoderLayer',
'MvpForSequenceClassification': 'MvpEncoderLayer',
'MvpPrompt': 'MvpEncoderLayer',
'NllbMoeEncoderLayer': 'NllbMoeEncoderLayer',
'NllbMoeForConditionalGeneration': 'NllbMoeEncoderLayer',
'PLBartEncoderLayer': 'BartEncoderLayer',
'PLBartForConditionalGeneration': 'BartEncoderLayer',
'TimeSeriesTransformerEncoderLayer': 'TimeSeriesTransformerEncoderLayer',
'TimeSeriesTransformerForPrediction': 'TimeSeriesTransformerEncoderLayer'}
<<<
import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    _rewrite_forward_clamp_float16,
)
pprint.pprint(_rewrite_forward_clamp_float16())
>>>
{'AutoformerEncoderLayer': [<class 'transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer'>],
'BartEncoderLayer': [<class 'transformers.models.bart.modeling_bart.BartEncoderLayer'>,
<class 'transformers.models.plbart.modeling_plbart.PLBartEncoderLayer'>],
'BigBirdPegasusEncoderLayer': [<class 'transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer'>],
'BlenderbotSmallEncoderLayer': [<class 'transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer'>],
'InformerEncoderLayer': [<class 'transformers.models.informer.modeling_informer.InformerEncoderLayer'>],
'LEDEncoderLayer': [<class 'transformers.models.led.modeling_led.LEDEncoderLayer'>],
'MarianEncoderLayer': [<class 'transformers.models.marian.modeling_marian.MarianEncoderLayer'>],
'MvpEncoderLayer': [<class 'transformers.models.mvp.modeling_mvp.MvpEncoderLayer'>],
'NllbMoeEncoderLayer': [<class 'transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer'>],
'TimeSeriesTransformerEncoderLayer': [<class 'transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer'>]}