Coverage of the Patches

Serialized Classes

The following code prints the transformers classes for which serialization functions are registered.

<<<

import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p

print("\n".join(sorted(p.serialization_functions())))

>>>

    BaseModelOutput
    DynamicCache
    EncoderDecoderCache
    MambaCache
    SlidingWindowCache
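The sketch below illustrates what serializing such a class roughly involves: telling the exporter how to split an instance into a flat list of children plus static context, and how to rebuild it. It assumes the registered functions follow the flatten/unflatten contract used by PyTorch pytree registration. ToyCache, flatten_toy_cache and unflatten_toy_cache are hypothetical stand-ins, not the library's code.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class ToyCache:
    """Hypothetical stand-in for a cache class such as DynamicCache."""
    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


def flatten_toy_cache(cache: ToyCache) -> Tuple[List[Any], List[str]]:
    # The children are what the exporter traces; the context is static
    # metadata (here the attribute names) needed to rebuild the object.
    context = ["key_cache", "value_cache"]
    children = [getattr(cache, name) for name in context]
    return children, context


def unflatten_toy_cache(children: List[Any], context: List[str]) -> ToyCache:
    # Rebuild a new object from the children, in the order given by the context.
    return ToyCache(**dict(zip(context, children)))


cache = ToyCache(key_cache=[1, 2], value_cache=[3, 4])
children, context = flatten_toy_cache(cache)
restored = unflatten_toy_cache(children, context)
print(children, context)  # [[1, 2], [3, 4]] ['key_cache', 'value_cache']
print(restored == cache)  # True
```

With PyTorch, a pair like this could be passed to torch.utils._pytree.register_pytree_node(ToyCache, flatten_toy_cache, unflatten_toy_cache). The leading underscore marks that module as private, so its signature may change between releases.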

Patched Classes

The following script lists, for each patched transformers class, the methods that are replaced.

<<<

import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p

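# Patch classes have names starting with "patched_"; each one records the class
# it patches (_PATCHED_CLASS_) and the names of the methods it replaces (_PATCHES_).
for name, cls in p.__dict__.items():
    if name.startswith("patched_"):
        print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}")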

>>>

    AttentionMaskConverter: 
    DynamicCache: reorder_cache, update, crop, from_batch_splits, get_seq_length
    GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
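The listing only records what is patched. The sketch below shows one plausible way such patch classes could be applied, assuming each one exposes _PATCHED_CLASS_ and _PATCHES_ as in the script above. The hypothetical helper apply_patches swaps the listed methods onto the target class inside a with block and restores the originals afterwards. Greeter and patched_Greeter are toy stand-ins, not transformers classes, and this is not presented as the library's own mechanism.

```python
import contextlib
from typing import Iterator


class Greeter:
    # Hypothetical class standing in for a transformers class such as DynamicCache.
    def greet(self) -> str:
        return "original"


class patched_Greeter:
    # Same attribute names as in the listing: the class to patch and the
    # names of the methods to replace on it.
    _PATCHED_CLASS_ = Greeter
    _PATCHES_ = ["greet"]

    def greet(self) -> str:
        return "patched"


@contextlib.contextmanager
def apply_patches(*patch_classes) -> Iterator[None]:
    """Hypothetical helper: swap patched methods in, then restore the originals."""
    saved = []
    try:
        for patch in patch_classes:
            target = patch._PATCHED_CLASS_
            for name in patch._PATCHES_:
                # Remember what the target class itself defines (None if inherited).
                saved.append((target, name, target.__dict__.get(name)))
                setattr(target, name, patch.__dict__[name])
        yield
    finally:
        # Undo in reverse order; delete attributes the target did not define itself.
        for target, name, original in reversed(saved):
            if original is None:
                delattr(target, name)
            else:
                setattr(target, name, original)


g = Greeter()
with apply_patches(patched_Greeter):
    print(g.greet())  # patched
print(g.greet())  # original
```

Restoring in a finally clause keeps the class intact even if an export fails inside the block. This matters because monkey-patching changes the class for every instance in the process.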

Half-Automated Rewrites

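The following script prints the docstring of every function in patch_module_helper whose name starts with _rewrite_. Each docstring names the transformers classes that the function rewrites.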

<<<

import onnx_diagnostic.torch_export_patches.patch_module_helper as p

for name, f in p.__dict__.items():
    if name.startswith("_rewrite_"):
        print(f.__doc__)

>>>

    BartEncoderLayer, PLBartEncoderLayer