Coverage of the Patches

Serialized Classes

The following code prints the list of serialized classes in transformers and diffusers.

<<<

import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p

print(
    "\n".join(
        sorted(
            t.__name__
            for t in p.serialization_functions(
                patch_transformers=True, patch_diffusers=True
            )
        )
    )
)

>>>

    BaseModelOutput
    DynamicCache
    EncoderDecoderCache
    MambaCache
    SlidingWindowCache
    StaticCache
    UNet2DConditionOutput
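
torch.export flattens its inputs and outputs into tensors through torch's pytree utilities, so a container class such as DynamicCache can only cross the export boundary if it comes with a flatten function and an unflatten function. The sketch below shows such a pair in plain Python for a hypothetical `ToyCache` class. It is not the library's code: `ToyCache`, `flatten_toy_cache`, `unflatten_toy_cache` and `TOY_SERIALIZATION` are illustrative names, and the library presumably registers its own pairs with torch's pytree registry rather than keeping a dictionary like this one.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class ToyCache:
    """Hypothetical stand-in for a cache class: one key and one value entry per layer."""

    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


def flatten_toy_cache(cache: ToyCache) -> Tuple[List[Any], List[str]]:
    # children: the parts an exporter traces (tensors in practice);
    # context: static information needed to rebuild the object (here, the field names)
    return [cache.key_cache, cache.value_cache], ["key_cache", "value_cache"]


def unflatten_toy_cache(children: List[Any], context: List[str]) -> ToyCache:
    # rebuild the object from the children and the context produced by flatten
    return ToyCache(**dict(zip(context, children)))


# Hypothetical registry: class -> (flatten, unflatten)
TOY_SERIALIZATION: Dict[type, tuple] = {ToyCache: (flatten_toy_cache, unflatten_toy_cache)}
```

A round trip through the two functions must return an object equal to the original one, which is the property an exporter relies on.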

Patched Classes

The following script lists, for each patched class in transformers, the methods replaced by the patches.

<<<

import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p

for name, cls in p.__dict__.items():
    if name.startswith("patched_") and hasattr(cls, "_PATCHES_"):
        print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")

>>>

    AttentionMaskConverter: 
    DynamicCache: reorder_cache, update, crop, from_batch_splits, get_seq_length
    GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
    GemmaRotaryEmbedding: forward
    Gemma2RotaryEmbedding: forward
    Gemma3RotaryEmbedding: forward
    LlamaRotaryEmbedding: forward
    MistralRotaryEmbedding: forward
    MixtralRotaryEmbedding: forward
    PhiRotaryEmbedding: forward
    Phi3RotaryEmbedding: forward
    Phi4MultimodalRotaryEmbedding: forward
    SmolLM3RotaryEmbedding: forward
    IdeficsEmbedding: forward
    IdeficsAttention: forward
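
The listing script relies on two attributes of each `patched_*` class: `_PATCHED_CLASS_`, the class being patched, and `_PATCHES_`, the names of the replaced methods. The sketch below shows how such attributes could drive a context manager, assuming that patching means swapping these methods onto the target class while exporting and restoring them afterwards. `RotaryEmbedding`, `patched_RotaryEmbedding` and `apply_patches` are hypothetical names, not the library's API.

```python
import contextlib


class RotaryEmbedding:
    """Hypothetical class to patch."""

    def forward(self, x):
        return f"original forward({x})"


class patched_RotaryEmbedding:
    # same attribute names as the ones read by the listing script above
    _PATCHED_CLASS_ = RotaryEmbedding
    _PATCHES_ = ["forward"]

    def forward(self, x):
        return f"patched forward({x})"


@contextlib.contextmanager
def apply_patches(*patch_classes):
    """Swaps the methods named in _PATCHES_ onto _PATCHED_CLASS_, then restores them.

    Assumes every patched method is defined on the target class itself (not inherited).
    """
    saved = []
    try:
        for patch in patch_classes:
            target = patch._PATCHED_CLASS_
            for name in patch._PATCHES_:
                # __dict__ keeps staticmethod/classmethod wrappers intact
                saved.append((target, name, target.__dict__[name]))
                setattr(target, name, patch.__dict__[name])
        yield
    finally:
        # restore in reverse order so nested patches unwind correctly
        for target, name, original in reversed(saved):
            setattr(target, name, original)
```

Restoring in a `finally` clause guarantees the original methods come back even if the export fails inside the `with` block.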

Half-Automated Rewrites for Control Flows

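The following script lists the classes whose forward method is automatically rewritten because of control flows. Each key is a class in transformers and each value is the layer whose rewritten forward it relies on. The same code is duplicated in many model classes, so the number of fixes is much smaller than the number of classes to fix.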

<<<

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    known_transformers_rewritings_clamp_float16,
)

pprint.pprint(known_transformers_rewritings_clamp_float16())

>>>

    {'AutoformerEncoder': 'AutoformerEncoderLayer',
     'AutoformerEncoderLayer': 'AutoformerEncoderLayer',
     'AutoformerForPrediction': 'AutoformerEncoderLayer',
     'AutoformerModel': 'AutoformerEncoderLayer',
     'BartEncoderLayer': 'BartEncoderLayer',
     'BartForConditionalGeneration': 'BartEncoderLayer',
     'BartModel': 'BartEncoderLayer',
     'BigBirdPegasusForCausalLM': 'BigBirdPegasusEncoderLayer',
     'BigBirdPegasusForConditionalGeneration': 'BigBirdPegasusEncoderLayer',
     'BigBirdPegasusForQuestionAnswering': 'BigBirdPegasusEncoderLayer',
     'BlenderbotSmallEncoderLayer': 'BlenderbotSmallEncoderLayer',
     'BlenderbotSmallForCausalLM': 'BlenderbotSmallEncoderLayer',
     'BlenderbotSmallForConditionalGeneration': 'BlenderbotSmallEncoderLayer',
     'InformerEncoderLayer': 'InformerEncoderLayer',
     'InformerForPrediction': 'InformerEncoderLayer',
     'LEDClassificationHead': 'LEDEncoderLayer',
     'LEDEncoderLayer': 'LEDEncoderLayer',
     'LEDForConditionalGeneration': 'LEDEncoderLayer',
     'MarianEncoder': 'MarianEncoderLayer',
     'MarianEncoderLayer': 'MarianEncoderLayer',
     'MarianMTModel': 'MarianEncoderLayer',
     'MarianModel': 'MarianEncoderLayer',
     'MvpEncoderLayer': 'MvpEncoderLayer',
     'MvpForCausalLM': 'MvpEncoderLayer',
     'MvpForConditionalGeneration': 'MvpEncoderLayer',
     'MvpForQuestionAnswering': 'MvpEncoderLayer',
     'MvpForSequenceClassification': 'MvpEncoderLayer',
     'MvpPrompt': 'MvpEncoderLayer',
     'NllbMoeEncoderLayer': 'NllbMoeEncoderLayer',
     'NllbMoeForConditionalGeneration': 'NllbMoeEncoderLayer',
     'PLBartEncoderLayer': 'BartEncoderLayer',
     'PLBartForConditionalGeneration': 'BartEncoderLayer',
     'TimeSeriesTransformerEncoderLayer': 'TimeSeriesTransformerEncoderLayer',
     'TimeSeriesTransformerForPrediction': 'TimeSeriesTransformerEncoderLayer'}
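
The output above has 34 entries but only 10 distinct values, one per rewritten layer. A small sketch makes this explicit by inverting the mapping, grouping the classes by the rewrite they depend on. The dictionary below is only an excerpt copied from that output, and `group_by_rewrite` is a hypothetical helper.

```python
from collections import defaultdict

# Excerpt copied from the output above: class -> layer whose forward is rewritten
rewritings = {
    "BartEncoderLayer": "BartEncoderLayer",
    "BartForConditionalGeneration": "BartEncoderLayer",
    "BartModel": "BartEncoderLayer",
    "PLBartEncoderLayer": "BartEncoderLayer",
    "PLBartForConditionalGeneration": "BartEncoderLayer",
    "MarianEncoder": "MarianEncoderLayer",
    "MarianEncoderLayer": "MarianEncoderLayer",
    "MarianMTModel": "MarianEncoderLayer",
    "MarianModel": "MarianEncoderLayer",
}


def group_by_rewrite(mapping):
    """Inverts {class: rewritten layer} into {rewritten layer: sorted classes}."""
    groups = defaultdict(list)
    for model_class, layer in mapping.items():
        groups[layer].append(model_class)
    return {layer: sorted(classes) for layer, classes in groups.items()}


groups = group_by_rewrite(rewritings)
for layer, classes in sorted(groups.items()):
    print(f"{layer}: {len(classes)} classes")
```

With onnx_diagnostic installed, passing `known_transformers_rewritings_clamp_float16()` instead of the excerpt should give one group per rewritten layer.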

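The next script shows, for each rewritten layer, the transformers classes receiving the rewritten forward. For example, the rewrite of BartEncoderLayer also applies to PLBartEncoderLayer.

<<<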

import pprint
from onnx_diagnostic.torch_export_patches.patch_module_helper import (
    _rewrite_forward_clamp_float16,
)

pprint.pprint(_rewrite_forward_clamp_float16())

>>>

    {'AutoformerEncoderLayer': [<class 'transformers.models.autoformer.modeling_autoformer.AutoformerEncoderLayer'>],
     'BartEncoderLayer': [<class 'transformers.models.bart.modeling_bart.BartEncoderLayer'>,
                          <class 'transformers.models.plbart.modeling_plbart.PLBartEncoderLayer'>],
     'BigBirdPegasusEncoderLayer': [<class 'transformers.models.bigbird_pegasus.modeling_bigbird_pegasus.BigBirdPegasusEncoderLayer'>],
     'BlenderbotSmallEncoderLayer': [<class 'transformers.models.blenderbot_small.modeling_blenderbot_small.BlenderbotSmallEncoderLayer'>],
     'InformerEncoderLayer': [<class 'transformers.models.informer.modeling_informer.InformerEncoderLayer'>],
     'LEDEncoderLayer': [<class 'transformers.models.led.modeling_led.LEDEncoderLayer'>],
     'MarianEncoderLayer': [<class 'transformers.models.marian.modeling_marian.MarianEncoderLayer'>],
     'MvpEncoderLayer': [<class 'transformers.models.mvp.modeling_mvp.MvpEncoderLayer'>],
     'NllbMoeEncoderLayer': [<class 'transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeEncoderLayer'>],
     'TimeSeriesTransformerEncoderLayer': [<class 'transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoderLayer'>]}
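
The clamp_float16 rewrites target an `if` whose condition depends on tensor values, which torch.export typically cannot resolve while tracing. As an illustration only, the sketch below uses numpy to contrast such a branch with a branch-free version that always computes the clamped tensor and selects the result with `np.where`. The clamp pattern (float16 maximum minus 1000, applied when the tensor contains inf or nan) is an assumption about the transformers code being rewritten, and the library may emit a different construct, such as `torch.cond`. Both function names are hypothetical.

```python
import numpy as np

# Assumed clamp bound: float16 maximum (65504) minus a margin of 1000
CLAMP_VALUE = float(np.finfo(np.float16).max) - 1000.0


def clamp_with_branch(hidden_states):
    # Python `if` on a data-dependent condition: the branch taken depends on
    # tensor values, which an exporter cannot know at export time.
    if hidden_states.dtype == np.float16 and (
        np.isinf(hidden_states).any() or np.isnan(hidden_states).any()
    ):
        return np.clip(hidden_states, -CLAMP_VALUE, CLAMP_VALUE)
    return hidden_states


def clamp_branch_free(hidden_states):
    # Both outcomes are computed and a tensor-level selection replaces the `if`,
    # so the same operations run whatever the values are.
    bad = np.isinf(hidden_states).any() | np.isnan(hidden_states).any()
    clamped = np.clip(hidden_states, -CLAMP_VALUE, CLAMP_VALUE)
    return np.where(bad, clamped, hidden_states)
```

The branch-free version does slightly more work, since the clamp always runs, but it gives an exporter a single graph with no data-dependent decision.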