Patches Explained

Function onnx_diagnostic.torch_export_patches.torch_export_patches() implements four kinds of patches that make it easier to export a model, usually one coming from transformers. All patches live in onnx_diagnostic.torch_export_patches.

Four Kinds of Patches

All four are applied through the same context manager:

with torch_export_patches(...) as f:
    ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)

  1. torch fixes: disables some exceptions or improves some functions related to dynamic shapes until torch addresses the underlying issues (mostly exporter issues)

  2. transformers rewriting: replaces some methods with a version torch.export.export() can understand; some of these rewritings may migrate to transformers, others are applied only at export time because they would make the implementation less efficient

  3. cache serialization: registers custom classes such as transformers.cache_utils.DynamicCache so that torch.export.export() knows how to serialize them

  4. control flow rewriting: control flow (if, for) cannot be exported as is; this package offers some automated rewriting, but it is far from perfect and some work remains before the process is fully automatic.
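
Each kind can be switched on or off independently through keyword arguments of torch_export_patches(). A minimal sketch (the flags are documented in the sections below):

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Model(torch.nn.Module):
    def forward(self, x):
        return x.abs() + 1

model, x = Model(), torch.randn(3, 4)

# enable only the torch fixes and the transformers rewriting
with torch_export_patches(patch_torch=True, patch_transformers=True):
    ep = torch.export.export(model, (x,))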

All of them are triggered by onnx_diagnostic.torch_export_patches.torch_export_patches(), for instance when validating a model from the command line:

python -m onnx_diagnostic validate \
    -m hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration \
    --run -v 1 --export onnx-dynamo -o dump_test --dtype float16 --device cuda

All patches can be disabled with torch_export_patches(patch=False).

torch fixes

Implemented in onnx_diagnostic.torch_export_patches.patches.patch_torch and triggered with torch_export_patches(patch_sympy=True, patch_torch=True, catch_constraints=True, stop_if_static=1, ...).

It fixes issues found while exporting models; some of these fixes might not be needed anymore. It also improves shape broadcasting and can raise an exception every time a dynamic dimension becomes static (stop_if_static=1).
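
A sketch showing how stop_if_static can surface the place where a dimension marked as dynamic is specialized into a constant:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Model(torch.nn.Module):
    def forward(self, x):
        return x.cos() + x.sin()

model, x = Model(), torch.randn(4, 5)

# stop_if_static=1 turns the silent specialization of a dynamic
# dimension into an exception, which makes the culprit easy to locate
with torch_export_patches(patch_torch=True, stop_if_static=1):
    ep = torch.export.export(
        model, (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
    )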

transformers rewriting

Implemented in onnx_diagnostic.torch_export_patches.patches.patch_transformers and triggered with torch_export_patches(patch_transformers=True).

Every patched class is prefixed with patched_ and carries two class attributes: _PATCHES_ lists the methods to replace, _PATCHED_CLASS_ is the class being patched.

class patched_AttentionMaskConverter:
    """
    Patches
    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
    """

    # This method was fixed in 4.51 at least.
    _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
    _PATCHED_CLASS_ = AttentionMaskConverter

The package automatically parses this file to extract the patched methods. More can be added by populating the argument custom_patches: torch_export_patches(patch_transformers=True, custom_patches=[...]) (a sketch of a custom patch follows the list below). Here is the list of available patches:

<<<

import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p

for name, cls in p.__dict__.items():
    if name.startswith("patched_") and hasattr(cls, "_PATCHES_"):
        print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")

>>>

    AttentionMaskConverter: 
    DynamicCache: reorder_cache, update, crop, from_batch_splits, get_seq_length
    GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
    Phi3RotaryEmbedding: forward
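
A custom patch follows the same convention. A hypothetical sketch (MyCache and get_max_length are illustrative, not part of any library):

from onnx_diagnostic.torch_export_patches import torch_export_patches

class MyCache:
    "Hypothetical class whose method needs an export-friendly replacement."

    def get_max_length(self):
        raise NotImplementedError

class patched_MyCache:
    "Follows the same convention as the built-in patches."

    _PATCHES_ = ["get_max_length"]  # methods to replace
    _PATCHED_CLASS_ = MyCache  # class being patched

    def get_max_length(self):
        # export-friendly reimplementation
        return 128

with torch_export_patches(patch_transformers=True, custom_patches=[patched_MyCache]):
    # inside the context, MyCache.get_max_length is the patched version
    ...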

Cache serialization

Implemented in onnx_diagnostic.torch_export_patches.onnx_export_serialization. Any custom class manipulated by a model needs to be registered through torch.utils._pytree.register_pytree_node or with onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization(), triggered by torch_export_patches(patch_transformers=True). This function registers one class; onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization() registers all known classes. Both can be undone with onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister() or onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization(). Here is the list of supported caches:

<<<

import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p

print("\n".join(sorted(p.serialization_functions())))

>>>

    BaseModelOutput
    DynamicCache
    EncoderDecoderCache
    MambaCache
    SlidingWindowCache
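
A sketch checking that a registered cache can round-trip through torch.utils._pytree once the context is entered (assuming a recent transformers version):

import torch
import torch.utils._pytree as pytree
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.torch_export_patches import torch_export_patches

cache = DynamicCache()
cache.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4), layer_idx=0)

# inside the context, DynamicCache is registered with torch.utils._pytree,
# so it can be flattened into tensors and rebuilt, which is what
# torch.export.export needs to accept it as input or output
with torch_export_patches(patch_transformers=True):
    flat, spec = pytree.tree_flatten(cache)
    restored = pytree.tree_unflatten(flat, spec)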

Control flow rewriting

This is an attempt to automatically rewrite control flow using the ast module. It is implemented in onnx_diagnostic.torch_export_patches.patch_module and triggered with torch_export_patches(rewrite=<instance of torch.nn.Module>). Option dump_rewriting=<folder> tells the function to dump all applied rewritings.
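
A minimal sketch on a model with a data-dependent test:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Model(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # data-dependent test torch.export.export cannot trace
            return x.cos()
        return x.sin()

model, x = Model(), torch.randn(4, 5)

# the forward method is rewritten with torch.cond before the export
# and restored afterwards; dump_rewriting saves the rewritten code
with torch_export_patches(rewrite=model, dump_rewriting="dump_rw"):
    ep = torch.export.export(model, (x,))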

The following example shows the rewriting of method transformers.models.bart.modeling_bart.BartEncoderLayer.forward(). The list of known rewritings to apply is returned by function onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting() and applied by function onnx_diagnostic.torch_export_patches.patch_module.transform_method().

The parsed code carries no type information, while torch.export.export() has it. Because of that, the automation usually needs manual tuning to filter out some tests (argument filter_node) or to pre/post process the code (arguments pre_rewriter, post_rewriter of function onnx_diagnostic.torch_export_patches.patch_module.transform_method()).

The main entry point is the context manager onnx_diagnostic.torch_export_patches.torch_export_rewrite(), which applies the rewriting and undoes it on exit. For example, the model transformers.BartForConditionalGeneration requires the following value for parameter rewrite:

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)

pprint.pprint(code_needing_rewriting("BartForConditionalGeneration"))

>>>

    [{'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x7e2bbf119ee0>,
      'function': <function BartEncoderLayer.forward at 0x7e2bbfd677e0>,
      'pre_rewriter': <function ast_or_into_bitor at 0x7e2bc0718900>},
     {'filter_node': <function rewritings_transformers_clamp_float16.<locals>.<lambda> at 0x7e2bbf119ee0>,
      'function': <function PLBartEncoderLayer.forward at 0x7e2bbfd72ac0>,
      'pre_rewriter': <function ast_or_into_bitor at 0x7e2bc0718900>}]

This method has two tests. Only the first one needs to be rewritten. The second one manipulates a tuple, which the automated rewriting does not handle because it cannot detect types; that explains why the parameter filter_node is filled. The first test includes a condition relying on or, which must be replaced by |; that explains the parameter pre_rewriter. We finally get:

--- original
+++ rewritten
@@ -26,7 +26,6 @@
    hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
    hidden_states = residual + hidden_states
    hidden_states = self.self_attn_layer_norm(hidden_states)
-
    residual = hidden_states
    hidden_states = self.activation_fn(self.fc1(hidden_states))
    hidden_states = nn.functional.dropout(
@@ -37,15 +36,22 @@
    hidden_states = residual + hidden_states
    hidden_states = self.final_layer_norm(hidden_states)

-    if hidden_states.dtype == torch.float16 and (
-        torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
-    ):
+    def branch_cond_then_1(hidden_states):
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+        return hidden_states.clone()

+    def branch_cond_else_1(hidden_states):
+        return hidden_states.clone()
+
+    hidden_states = torch.cond(
+        hidden_states.dtype == torch.float16
+        and torch.isinf(hidden_states).any() | torch.isnan(hidden_states).any(),
+        branch_cond_then_1,
+        branch_cond_else_1,
+        [hidden_states],
+    )
    outputs = (hidden_states,)
-
    if output_attentions:
-        outputs += (attn_weights,)
-
+        outputs = outputs + (attn_weights,)
    return outputs
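
Putting it together, a sketch exporting the model with the rewriting applied and undone automatically (the checkpoint name and inputs are illustrative; a real export usually also passes dynamic_shapes):

import torch
from transformers import BartForConditionalGeneration
from onnx_diagnostic.torch_export_patches import torch_export_rewrite
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    code_needing_rewriting,
)

model = BartForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-BartForConditionalGeneration"
)
inputs = dict(
    input_ids=torch.randint(0, model.config.vocab_size, (1, 8)),
    decoder_input_ids=torch.randint(0, model.config.vocab_size, (1, 8)),
)

# the rewriting is applied on __enter__ and undone on __exit__
with torch_export_rewrite(rewrite=code_needing_rewriting("BartForConditionalGeneration")):
    ep = torch.export.export(model, (), kwargs=inputs)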

The locations where this rewriting needs to be applied:

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    known_transformers_rewritings_clamp_float16,
)

pprint.pprint(known_transformers_rewritings_clamp_float16())

>>>

    {'AutoformerEncoder': 'AutoformerEncoderLayer',
     'AutoformerEncoderLayer': 'AutoformerEncoderLayer',
     'AutoformerForPrediction': 'AutoformerEncoderLayer',
     'AutoformerModel': 'AutoformerEncoderLayer',
     'BartEncoderLayer': 'BartEncoderLayer',
     'BartForConditionalGeneration': 'BartEncoderLayer',
     'BigBirdPegasusForCausalLM': 'BigBirdPegasusEncoderLayer',
     'BigBirdPegasusForConditionalGeneration': 'BigBirdPegasusEncoderLayer',
     'BigBirdPegasusForQuestionAnswering': 'BigBirdPegasusEncoderLayer',
     'BlenderbotSmallEncoderLayer': 'BlenderbotSmallEncoderLayer',
     'BlenderbotSmallForCausalLM': 'BlenderbotSmallEncoderLayer',
     'BlenderbotSmallForConditionalGeneration': 'BlenderbotSmallEncoderLayer',
     'InformerEncoderLayer': 'InformerEncoderLayer',
     'InformerForPrediction': 'InformerEncoderLayer',
     'LEDClassificationHead': 'LEDEncoderLayer',
     'LEDEncoderLayer': 'LEDEncoderLayer',
     'LEDForConditionalGeneration': 'LEDEncoderLayer',
     'MarianEncoder': 'MarianEncoderLayer',
     'MarianEncoderLayer': 'MarianEncoderLayer',
     'MarianMTModel': 'MarianEncoderLayer',
     'MarianModel': 'MarianEncoderLayer',
     'MvpEncoderLayer': 'MvpEncoderLayer',
     'MvpForCausalLM': 'MvpEncoderLayer',
     'MvpForConditionalGeneration': 'MvpEncoderLayer',
     'MvpForQuestionAnswering': 'MvpEncoderLayer',
     'MvpForSequenceClassification': 'MvpEncoderLayer',
     'MvpPrompt': 'MvpEncoderLayer',
     'NllbMoeEncoderLayer': 'NllbMoeEncoderLayer',
     'NllbMoeForConditionalGeneration': 'NllbMoeEncoderLayer',
     'PLBartEncoderLayer': 'BartEncoderLayer',
     'PLBartForConditionalGeneration': 'BartEncoderLayer',
     'TimeSeriesTransformerEncoderLayer': 'TimeSeriesTransformerEncoderLayer',
     'TimeSeriesTransformerForPrediction': 'TimeSeriesTransformerEncoderLayer'}

Function onnx_diagnostic.torch_export_patches.patch_module_helper._rewrite_forward_clamp_float16() maps every rewriting name to the classes whose forward method is rewritten:

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    _rewrite_forward_clamp_float16,
)

pprint.pprint(_rewrite_forward_clamp_float16())

>>>

    {'AutoformerEncoderLayer': [<class 'transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer'>],
     'BartEncoderLayer': [<class 'transformers.models.bart.modeling_bart.BartEncoderLayer'>,
                          <class 'transformers.models.plbart.modeling_plbart.PLBartEncoderLayer'>],
     'BigBirdPegasusEncoderLayer': [<class 'transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer'>],
     'BlenderbotSmallEncoderLayer': [<class 'transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer'>],
     'InformerEncoderLayer': [<class 'transformers.models.informer.modeling_informer.InformerEncoderLayer'>],
     'LEDEncoderLayer': [<class 'transformers.models.led.modeling_led.LEDEncoderLayer'>],
     'MarianEncoderLayer': [<class 'transformers.models.marian.modeling_marian.MarianEncoderLayer'>],
     'MvpEncoderLayer': [<class 'transformers.models.mvp.modeling_mvp.MvpEncoderLayer'>],
     'NllbMoeEncoderLayer': [<class 'transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer'>],
     'TimeSeriesTransformerEncoderLayer': [<class 'transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer'>]}