onnx_diagnostic.torch_export_patches.patches.patch_transformers¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.common_RotaryEmbedding(*args, **kwargs)[source][source]¶
- forward(x, position_ids)[source][source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.common_eager_attention_forward(module: Module, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float | None = None, dropout: float = 0.0, head_mask: Tensor | None = None, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter[source][source]¶
Patches
transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask
.
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin[source][source]¶
Applies modifications implemented in PR transformers/#36652.
- prepare_inputs_for_generation(input_ids: LongTensor, past_key_values: Cache | None = None, attention_mask: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, cache_position: LongTensor | None = None, **kwargs)[source][source]¶
Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or slicing inputs given the existing cache.
See the forward pass in the model documentation for expected arguments (different models might have different requirements for e.g. past_key_values). This function should work as is for most LLMs.
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention(*args, **kwargs)[source][source]¶
- forward(hidden_states: Tensor, key_value_states: Tensor | None = None, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_value: Tuple[Tensor] | None = None, output_attentions: bool = False, use_cache: bool = False, cache_position: LongTensor | None = None, **kwargs) Tuple[Tensor, Tensor | None, Tuple[Tensor] | None] [source][source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding(*args, **kwargs)[source][source]¶
- forward(x, seq_len=None)[source][source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding(*args, **kwargs)[source][source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder(*args, **kwargs)[source][source]¶
- forward(image_embeddings: Tensor, image_positional_embeddings: Tensor, sparse_prompt_embeddings: Tensor, dense_prompt_embeddings: Tensor, multimask_output: bool, output_attentions: bool | None = None, attention_similarity: Tensor | None = None, target_embedding: Tensor | None = None) tuple[Tensor, Tensor] [source][source]¶
Predict masks given image and prompt embeddings.
- Args:
- image_embeddings (torch.Tensor):
the embeddings from the image encoder
- image_positional_embedding (torch.Tensor):
positional encoding with the shape of image_embeddings
- sparse_prompt_embeddings (torch.Tensor):
The embeddings of the points and boxes
- dense_prompt_embeddings (torch.Tensor):
the embeddings of the mask inputs
- multimask_output (bool):
Whether to return multiple masks or a single mask.
- output_attentions (bool, optional):
Whether or not to return the attentions tensors of all attention layers.
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding(*args, **kwargs)[source][source]¶
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__compute_dynamic_ntk_parameters(config: PretrainedConfig | None = None, device: device | None = None, seq_len: int | None = None, **rope_kwargs) Tuple[Tensor, float] [source][source]¶
manual patch:
[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
- Args:
- config ([~transformers.PretrainedConfig]):
The model configuration.
- device (torch.device):
The device to use for initialization of the inverse frequencies.
- seq_len (int, optional):
The current sequence length, used to update the dynamic RoPE at inference time.
- rope_kwargs (Dict, optional):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
- Returns:
Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the omputed cos/sin (unused in this type of RoPE).
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) Callable [source][source]¶
manual patch for function
transformers.masking_utils._vmap_for_bhqkv
.
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_dynamic_rope_update(rope_forward)[source][source]¶
manual patch:
[patch:transformers.modeling_rope_utils.dynamic_rope_update]
rope_type
is determined in the constructor of classtransformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
.if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default"
The original code of the patched function:
def dynamic_rope_update(rope_forward): def longrope_frequency_update(self, position_ids, device): seq_len = torch.max(position_ids) + 1 if hasattr(self.config, "original_max_position_embeddings"): original_max_position_embeddings = self.config.original_max_position_embeddings else: original_max_position_embeddings = self.config.max_position_embeddings if seq_len > original_max_position_embeddings: if not hasattr(self, "long_inv_freq"): self.long_inv_freq, _ = self.rope_init_fn( self.config, device, seq_len=original_max_position_embeddings + 1 ) self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) else: self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) def dynamic_frequency_update(self, position_ids, device): seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @wraps(rope_forward) def wrapper(self, x, position_ids): if "dynamic" in self.rope_type: dynamic_frequency_update(self, position_ids, device=x.device) elif self.rope_type == "longrope": longrope_frequency_update(self, position_ids, device=x.device) return rope_forward(self, x, position_ids) return wrapper
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_eager_mask(batch_size: int, cache_position: ~torch.Tensor, kv_length: int, kv_offset: int = 0, mask_function: ~typing.Callable = <function causal_mask_function>, attention_mask: ~torch.Tensor | None = None, dtype: ~torch.dtype = torch.float32, **kwargs) Tensor [source][source]¶
manual patch for function
transformers.masking_utils.eager_mask
.
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_model_bart_eager_attention_forward(module: Module, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float | None = None, dropout: float = 0.0, head_mask: Tensor | None = None, **kwargs)[source][source]¶
[patch:transformers.models.bart.modeling_bart.eager_attention_forward]
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_modeling_marian_eager_attention_forward(module: Module, query: Tensor, key: Tensor, value: Tensor, attention_mask: Tensor | None, scaling: float | None = None, dropout: float = 0.0, head_mask: Tensor | None = None, **kwargs)[source][source]¶
[patch:transformers.models.marian.modeling_marian.eager_attention_forward]