onnx_diagnostic.torch_export_patches.patches.patch_transformers

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.common_RotaryEmbedding(*args, **kwargs)[source]
forward(x, position_ids)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
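
For orientation, a minimal sketch of what a rotary-embedding forward typically computes: cos/sin tables built from an inv_freq buffer and position_ids. The buffer names and dtype handling are assumptions, not the exact attributes of common_RotaryEmbedding.

import torch

def rope_cos_sin(inv_freq, position_ids, dtype=torch.float32):
    # inv_freq: (dim/2,), position_ids: (batch, seq_len)
    # freqs[b, s, j] = position_ids[b, s] * inv_freq[j]
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :].float()
    emb = torch.cat((freqs, freqs), dim=-1)  # (batch, seq_len, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
cos, sin = rope_cos_sin(inv_freq, torch.arange(8)[None, :])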

onnx_diagnostic.torch_export_patches.patches.patch_transformers.common_eager_attention_forward(module: Module, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float | None = None, dropout: float = 0.0, head_mask: Tensor | None = None, **kwargs)[source]
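
A hedged sketch of the eager attention computation this helper mirrors; the shapes and the head_mask handling below are assumptions based on the signature, not the patched implementation itself.

import torch
import torch.nn.functional as F

def eager_attention(query, key, value, attention_mask=None, scaling=None,
                    dropout=0.0, head_mask=None, training=False):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:  # additive mask: 0 where allowed, large negative otherwise
        attn_weights = attn_weights + attention_mask[..., : key.shape[-2]]
    attn_weights = F.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask
    attn_weights = F.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights

q = k = v = torch.randn(1, 2, 4, 8)
out, weights = eager_attention(q, k, v)
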
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter[source]

Patches transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask.
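
As a reference point, a simplified version of the kind of additive causal mask _make_causal_mask produces (without sliding-window handling; the argument names are illustrative):

import torch

def make_causal_mask(seq_len, dtype=torch.float32, past_len=0):
    # Additive mask: 0 where attention is allowed, dtype.min where it is not.
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    if past_len > 0:  # new tokens may always attend to the cached past
        mask = torch.cat([torch.zeros(seq_len, past_len, dtype=dtype), mask], dim=-1)
    return mask[None, None, :, :]  # (1, 1, seq_len, past_len + seq_len)

mask = make_causal_mask(4, past_len=2)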

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin[source]

Applies modifications implemented in PR transformers/#36652.

prepare_inputs_for_generation(input_ids: LongTensor, past_key_values: Cache | None = None, attention_mask: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, cache_position: LongTensor | None = None, **kwargs)[source]

Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or slicing inputs given the existing cache.

See the forward pass in the model documentation for expected arguments (different models might have different requirements for e.g. past_key_values). This function should work as is for most LLMs.
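
An illustrative sketch of the input slicing mentioned above (not the patched method itself): with a non-empty cache, only the tokens selected by cache_position are fed to the model.

import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15]])
cache_position = torch.tensor([4])              # one new token after 4 cached positions
model_input_ids = input_ids[:, cache_position]  # tensor([[15]])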

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention(*args, **kwargs)[source]
forward(hidden_states: Tensor, key_value_states: Tensor | None = None, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_value: Tuple[Tensor] | None = None, output_attentions: bool = False, use_cache: bool = False, cache_position: LongTensor | None = None, **kwargs) Tuple[Tensor, Tensor | None, Tuple[Tensor] | None][source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding(*args, **kwargs)[source]
forward(x, seq_len=None)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding(*args, **kwargs)[source]
class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder(*args, **kwargs)[source]
forward(image_embeddings: Tensor, image_positional_embeddings: Tensor, sparse_prompt_embeddings: Tensor, dense_prompt_embeddings: Tensor, multimask_output: bool, output_attentions: bool | None = None, attention_similarity: Tensor | None = None, target_embedding: Tensor | None = None) tuple[Tensor, Tensor][source]

Predict masks given image and prompt embeddings.

Args:

image_embeddings (torch.Tensor): the embeddings from the image encoder.

image_positional_embeddings (torch.Tensor): positional encoding with the shape of image_embeddings.

sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes.

dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs.

multimask_output (bool): whether to return multiple masks or a single mask.

output_attentions (bool, optional): whether or not to return the attention tensors of all attention layers.

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding(*args, **kwargs)[source]
onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__compute_dynamic_ntk_parameters(config: PretrainedConfig | None = None, device: device | None = None, seq_len: int | None = None, **rope_kwargs) Tuple[Tensor, float][source]

manual patch: [patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]

Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

Args:

config ([~transformers.PretrainedConfig]): the model configuration.

device (torch.device): the device to use for initialization of the inverse frequencies.

seq_len (int, optional): the current sequence length, used to update the dynamic RoPE at inference time.

rope_kwargs (Dict, optional): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.

Returns:

Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
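
A condensed sketch of the dynamic NTK computation (config handling omitted; the base-rescaling formula follows the upstream implementation, but treat the details as an approximation):

import torch

def dynamic_ntk_inv_freq(base, dim, factor, max_position_embeddings, seq_len):
    # NTK scaling: enlarge the rotary base once seq_len exceeds the training
    # length, so frequencies are stretched instead of truncated.
    seq_len = max(seq_len, max_position_embeddings)
    scaled_base = base * (
        (factor * seq_len / max_position_embeddings) - (factor - 1)
    ) ** (dim / (dim - 2))
    inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))
    return inv_freq, 1.0  # the cos/sin scaling factor is unused for this RoPE type

inv_freq, _ = dynamic_ntk_inv_freq(10000.0, 64, 2.0, 2048, 4096)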

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) Callable[source]

manual patch for function transformers.masking_utils._vmap_for_bhqkv.
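
The upstream helper vmaps a scalar mask_function(batch_idx, head_idx, q_idx, kv_idx) over the batch/head/query/key dimensions. A broadcasting-based sketch of the same idea follows; it illustrates the intent, not the patched code.

import torch

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

def mask_for_bhqkv(mask_function, batch_idx, head_idx, q_idx, kv_idx):
    # Evaluate the scalar mask function on a (batch, head, q_len, kv_len) grid
    # by broadcasting the four index vectors against each other.
    return mask_function(
        batch_idx[:, None, None, None],
        head_idx[None, :, None, None],
        q_idx[None, None, :, None],
        kv_idx[None, None, None, :],
    )

mask = mask_for_bhqkv(
    causal_mask_function,
    torch.arange(1), torch.arange(1), torch.arange(4), torch.arange(6),
)  # shape (1, 1, 4, 6)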

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_dynamic_rope_update(rope_forward)[source]

manual patch: [patch:transformers.modeling_rope_utils.dynamic_rope_update]

rope_type is determined in the constructor of class transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding.

if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
    self.rope_type = config.rope_scaling.get(
        "rope_type", config.rope_scaling.get("type"))
else:
    self.rope_type = "default"

The original code of the patched function:

def dynamic_rope_update(rope_forward):
    def longrope_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = (
                self.config.original_max_position_embeddings
            )
        else:
            original_max_position_embeddings = self.config.max_position_embeddings
        if seq_len > original_max_position_embeddings:
            if not hasattr(self, "long_inv_freq"):
                self.long_inv_freq, _ = self.rope_init_fn(
                    self.config, device, seq_len=original_max_position_embeddings + 1
                )
            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
        else:
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

    def dynamic_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        if "dynamic" in self.rope_type:
            dynamic_frequency_update(self, position_ids, device=x.device)
        elif self.rope_type == "longrope":
            longrope_frequency_update(self, position_ids, device=x.device)
        return rope_forward(self, x, position_ids)

    return wrapper

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_eager_mask(batch_size: int, cache_position: ~torch.Tensor, kv_length: int, kv_offset: int = 0, mask_function: ~typing.Callable = <function causal_mask_function>, attention_mask: ~torch.Tensor | None = None, dtype: ~torch.dtype = torch.float32, **kwargs) Tensor[source]

manual patch for function transformers.masking_utils.eager_mask.
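
A hedged sketch of what an eager mask looks like: evaluate the causal pattern on a (query, key/value) grid and convert it to the additive float mask eager attention consumes. The actual patch also has to respect attention_mask and the mask_function argument and remain export-friendly.

import torch

def eager_mask_sketch(batch_size, cache_position, kv_length, kv_offset=0,
                      dtype=torch.float32):
    kv_index = torch.arange(kv_length) + kv_offset           # absolute key/value positions
    allowed = kv_index[None, :] <= cache_position[:, None]   # causal pattern, (q_len, kv_length)
    mask = torch.full(allowed.shape, torch.finfo(dtype).min, dtype=dtype)
    mask = mask.masked_fill(allowed, 0.0)                    # 0 where attention is allowed
    return mask[None, None].expand(batch_size, 1, -1, -1)

mask = eager_mask_sketch(2, torch.arange(4), kv_length=4)    # (2, 1, 4, 4)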

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_model_bart_eager_attention_forward(module: Module, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float | None = None, dropout: float = 0.0, head_mask: Tensor | None = None, **kwargs)[source]

[patch:transformers.models.bart.modeling_bart.eager_attention_forward]

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_modeling_marian_eager_attention_forward(module: Module, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float | None = None, dropout: float = 0.0, head_mask: Tensor | None = None, **kwargs)[source]

[patch:transformers.models.marian.modeling_marian.eager_attention_forward]

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_parse_processor_args(processor_class: type[CacheProcessor] | None, kwargs: dict) tuple[dict, dict][source]

[patch:transformers.cache_utils.parse_processor_args]
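
A hedged sketch of what such a helper is expected to do, splitting kwargs between those accepted by the cache processor's constructor and the rest; the real patch may differ in details.

import inspect

def parse_processor_args_sketch(processor_class, kwargs):
    # Split kwargs into those matching the processor constructor's parameters
    # and the remaining ones.
    if processor_class is None:
        return {}, dict(kwargs)
    accepted = set(inspect.signature(processor_class.__init__).parameters) - {"self"}
    processor_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
    other_kwargs = {k: v for k, v in kwargs.items() if k not in accepted}
    return processor_kwargs, other_kwargs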