onnx_diagnostic.torch_export_patches.patches.patch_transformers¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.common_RotaryEmbedding(*args, **kwargs)[source]¶
- forward(x, position_ids)[source]¶
- Define the computation performed at every call. Should be overridden by all subclasses. Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
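- For context, a minimal sketch of the cos/sin computation such a rotary-embedding forward typically performs in transformers (an illustrative standalone helper, not the patched code itself):

    import torch

    def rotary_cos_sin(inv_freq, position_ids, attention_scaling=1.0):
        # inv_freq: (dim/2,), position_ids: (batch, seq_len)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (batch, seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)                              # (batch, seq_len, dim)
        return emb.cos() * attention_scaling, emb.sin() * attention_scaling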
 
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter[source]¶
- Patches transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask.
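- For reference, _make_causal_mask essentially fills the strictly upper triangle of a (seq_len, seq_len) matrix with the dtype's minimum value. A simplified sketch (a hypothetical helper that ignores past key/value length and sliding-window handling):

    import torch

    def make_causal_mask_sketch(seq_len, dtype=torch.float32):
        # Future positions get the most negative representable value so that
        # softmax assigns them (near) zero probability.
        mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
        mask = torch.triu(mask, diagonal=1)   # keep only the strictly upper triangle
        return mask[None, None, :, :]         # (1, 1, seq_len, seq_len), broadcastable over batch/heads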
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicCache[source]¶
- Applies modifications implemented in PR transformers/#36652.
- crop(max_length: int)[source]¶
- Crop the past key values up to a new max_length in terms of tokens. max_length can also be negative to remove max_length tokens. This is used in assisted decoding and contrastive search. 
 - classmethod from_batch_splits(splits: List[DynamicCache]) → DynamicCache[source]¶
- This is the opposite of the batch_split() method above; it is used by stack_model_outputs in generation.utils. 
 - get_seq_length(layer_idx: int | None = 0) → int[source]¶
- Returns the sequence length of the cached states. A layer index can be optionally passed. 
 - reorder_cache(beam_idx: LongTensor)[source]¶
- Reorders the cache for beam search, given the selected beam indices. 
 - update(key_states: Tensor, value_states: Tensor, layer_idx: int, cache_kwargs: Dict[str, Any] | None = None) → Tuple[Tensor, Tensor][source]¶
- Updates the cache with the new key_states and value_states for the layer layer_idx.
- Parameters:
- key_states (torch.Tensor):
- The new key states to cache. 
- value_states (torch.Tensor):
- The new value states to cache. 
- layer_idx (int):
- The index of the layer to cache the states for. 
- cache_kwargs (Dict[str, Any], optional):
- Additional arguments for the cache subclass. No additional arguments are used in DynamicCache. 
 
- Return:
- A tuple containing the updated key and value states. 
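- A short usage sketch of the update/get_seq_length API documented above, shown with transformers' DynamicCache, whose interface the patched class keeps (shapes are illustrative assumptions):

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    key_states = torch.randn(1, 8, 4, 64)     # (batch, num_heads, seq_len, head_dim)
    value_states = torch.randn(1, 8, 4, 64)
    keys, values = cache.update(key_states, value_states, layer_idx=0)
    print(keys.shape, cache.get_seq_length(0))  # torch.Size([1, 8, 4, 64]) 4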
 
 
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin[source]¶
- Applies modifications implemented in PR transformers/#36652.
- prepare_inputs_for_generation(input_ids: LongTensor, past_key_values: Cache | None = None, attention_mask: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, cache_position: LongTensor | None = None, **kwargs)[source]¶
- Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or slicing inputs given the existing cache (see the sketch below). See the forward pass in the model documentation for expected arguments (different models might have different requirements, e.g. for past_key_values). This function should work as is for most LLMs. 
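- A minimal sketch of the input-slicing step mentioned above, assuming a non-empty cache: only the positions listed in cache_position still need to be fed to the model.

    import torch

    input_ids = torch.tensor([[11, 12, 13, 14]])  # full sequence generated so far
    cache_position = torch.tensor([3])            # only the last position is not cached yet
    sliced = input_ids[:, cache_position]         # what the model actually receives
    print(sliced)                                 # tensor([[14]])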
 
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention(*args, **kwargs)[source]¶
- forward(hidden_states: Tensor, key_value_states: Tensor | None = None, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_value: Tuple[Tensor] | None = None, output_attentions: bool = False, use_cache: bool = False, cache_position: LongTensor | None = None, **kwargs) → Tuple[Tensor, Tensor | None, Tuple[Tensor] | None][source]¶
- Define the computation performed at every call. Should be overridden by all subclasses. Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
 
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding(*args, **kwargs)[source]¶
- forward(x, seq_len=None)[source]¶
- Define the computation performed at every call. Should be overridden by all subclasses. Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
 
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding(*args, **kwargs)[source]¶
- class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding(*args, **kwargs)[source]¶
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__compute_dynamic_ntk_parameters(config: PretrainedConfig | None = None, device: device | None = None, seq_len: int | None = None, **rope_kwargs) → Tuple[Tensor, float][source]¶
- manual patch: [patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]
- Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.
- Args:
- config ([~transformers.PretrainedConfig]):
- The model configuration. 
- device (torch.device):
- The device to use for initialization of the inverse frequencies. 
- seq_len (int, optional):
- The current sequence length, used to update the dynamic RoPE at inference time. 
- rope_kwargs (Dict, optional):
- BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. 
 
- Returns:
- Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
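- A hedged, simplified sketch of the dynamic-NTK idea: the RoPE base is rescaled with the current sequence length before the inverse frequencies are computed. The parameters base, dim, max_position_embeddings and factor stand in for the corresponding fields of the model configuration:

    import torch

    def dynamic_ntk_inv_freq(base, dim, max_position_embeddings, factor, seq_len):
        # Never scale below the training context length.
        seq_len = max(seq_len, max_position_embeddings)
        scaled_base = base * (
            (factor * seq_len / max_position_embeddings) - (factor - 1)
        ) ** (dim / (dim - 2))
        inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        return inv_freq, 1.0  # the post-processing scaling factor is unused for this RoPE type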
 
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) → Callable[source]¶
- manual patch for function transformers.masking_utils._vmap_for_bhqkv.
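- The wrapped mask_function takes (batch_idx, head_idx, q_idx, kv_idx) and says whether that query/key pair may attend. A small broadcasting-based sketch of the mask such a function yields over a full (batch, heads, q_len, kv_len) grid; this is an illustration only, the patched implementation may differ:

    import torch

    def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
        return kv_idx <= q_idx

    def build_mask(batch, heads, q_len, kv_len):
        b = torch.arange(batch)[:, None, None, None]
        h = torch.arange(heads)[None, :, None, None]
        q = torch.arange(q_len)[None, None, :, None]
        kv = torch.arange(kv_len)[None, None, None, :]
        return causal_mask_function(b, h, q, kv)  # bool mask of shape (batch, heads, q_len, kv_len)

    print(build_mask(1, 1, 3, 3)[0, 0])
    # tensor([[ True, False, False],
    #         [ True,  True, False],
    #         [ True,  True,  True]])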
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_dynamic_rope_update(rope_forward)[source]¶
- manual patch: [patch:transformers.modeling_rope_utils.dynamic_rope_update]
- rope_type is determined in the constructor of class transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding:

    if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
        self.rope_type = config.rope_scaling.get(
            "rope_type", config.rope_scaling.get("type"))
    else:
        self.rope_type = "default"

- The original code of the patched function:

    def dynamic_rope_update(rope_forward):
        def longrope_frequency_update(self, position_ids, device):
            seq_len = torch.max(position_ids) + 1
            if hasattr(self.config, "original_max_position_embeddings"):
                original_max_position_embeddings = self.config.original_max_position_embeddings
            else:
                original_max_position_embeddings = self.config.max_position_embeddings
            if seq_len > original_max_position_embeddings:
                if not hasattr(self, "long_inv_freq"):
                    self.long_inv_freq, _ = self.rope_init_fn(
                        self.config, device, seq_len=original_max_position_embeddings + 1
                    )
                self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
            else:
                self.original_inv_freq = self.original_inv_freq.to(device)
                self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

        def dynamic_frequency_update(self, position_ids, device):
            seq_len = torch.max(position_ids) + 1
            if seq_len > self.max_seq_len_cached:  # growth
                inv_freq, self.attention_scaling = self.rope_init_fn(
                    self.config, device, seq_len=seq_len)
                self.register_buffer("inv_freq", inv_freq, persistent=False)
                self.max_seq_len_cached = seq_len
            if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
                self.original_inv_freq = self.original_inv_freq.to(device)
                self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
                self.max_seq_len_cached = self.original_max_seq_len

        @wraps(rope_forward)
        def wrapper(self, x, position_ids):
            if "dynamic" in self.rope_type:
                dynamic_frequency_update(self, position_ids, device=x.device)
            elif self.rope_type == "longrope":
                longrope_frequency_update(self, position_ids, device=x.device)
            return rope_forward(self, x, position_ids)

        return wrapper
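- A hedged usage sketch: the patch is meant to be applied as a decorator on a rotary embedding's forward, like the original transformers decorator. MyRotaryEmbedding below is a hypothetical, simplified module written only to show where the decorator goes:

    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
        patched_dynamic_rope_update,
    )

    class MyRotaryEmbedding(torch.nn.Module):  # hypothetical module
        rope_type = "default"  # "dynamic" or "longrope" would trigger the update paths

        @patched_dynamic_rope_update
        def forward(self, x, position_ids):
            # ... compute cos/sin from self.inv_freq and position_ids ...
            return x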
- onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_eager_mask(batch_size: int, cache_position: torch.Tensor, kv_length: int, kv_offset: int = 0, mask_function: typing.Callable = <function causal_mask_function>, attention_mask: torch.Tensor | None = None, dtype: torch.dtype = torch.float32, **kwargs) → Tensor[source]¶
- manual patch for function transformers.masking_utils.eager_mask.
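- A hedged usage sketch based only on the signature above (the printed shape is an assumption): build a float causal mask for four query positions over a four-token key/value window, where masked positions receive torch.finfo(dtype).min.

    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import patched_eager_mask

    mask = patched_eager_mask(
        batch_size=1,
        cache_position=torch.arange(4),  # query positions
        kv_length=4,
        dtype=torch.float32,
    )
    print(mask.shape)  # expected: torch.Size([1, 1, 4, 4])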