onnx_diagnostic.torch_export_patches.patches.patch_transformers

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter

Patches transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask.
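The patched class keeps the semantics of the original method: future positions are filled with the most negative representable value and current or past positions with zero. A minimal sketch of that idea (not the patched implementation, which also handles batch size, past key values and sliding windows):

import torch

def make_causal_mask_sketch(seq_len: int, dtype=torch.float32) -> torch.Tensor:
    # positions a query is not allowed to attend to get the dtype minimum,
    # allowed positions get 0
    mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    return torch.triu(mask, diagonal=1)  # keep only the strictly upper triangle

print(make_causal_mask_sketch(4))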

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicCache

Applies modifications implemented in PR transformers/#36652.

crop(max_length: int)

Crops the past key values up to a new max_length, expressed in number of tokens. max_length can also be negative, in which case that many tokens are removed from the end. This is used in assisted decoding and contrastive search.
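A short usage sketch, assuming a transformers version where DynamicCache exposes update, get_seq_length and crop with the behaviour described here:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# one layer, batch=1, heads=2, seq_len=6, head_dim=4
cache.update(torch.randn(1, 2, 6, 4), torch.randn(1, 2, 6, 4), layer_idx=0)

print(cache.get_seq_length())  # 6
cache.crop(4)                  # keep the first 4 tokens
print(cache.get_seq_length())  # 4
cache.crop(-1)                 # negative value: remove that many tokens
print(cache.get_seq_length())  # 3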

classmethod from_batch_splits(splits: List[DynamicCache]) → DynamicCache

This is the opposite of the batch_split() method. It is used by stack_model_outputs in generation.utils.

get_seq_length(layer_idx: int | None = 0) → int

Returns the sequence length of the cached states. A layer index can be optionally passed.

reorder_cache(beam_idx: LongTensor)

Reorders the cache for beam search, given the selected beam indices.

update(key_states: Tensor, value_states: Tensor, layer_idx: int, cache_kwargs: Dict[str, Any] | None = None) → Tuple[Tensor, Tensor]

Updates the cache with the new key_states and value_states for the layer layer_idx.

Parameters:

key_states (torch.Tensor): The new key states to cache.
value_states (torch.Tensor): The new value states to cache.
layer_idx (int): The index of the layer to cache the states for.
cache_kwargs (Dict[str, Any], optional): Additional arguments for the cache subclass. No additional arguments are used in DynamicCache.

Return:

A tuple containing the updated key and value states.
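A hedged sketch of how the return value behaves: each call returns the full cached key and value tensors for that layer, with the new states concatenated along the sequence dimension:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
keys, values = cache.update(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4), layer_idx=0)
print(keys.shape)  # torch.Size([1, 2, 3, 4])

keys, values = cache.update(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4), layer_idx=0)
print(keys.shape)  # torch.Size([1, 2, 4, 4]), old and new states concatenated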

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin

Applies modifications implemented in PR transformers/#36652.

prepare_inputs_for_generation(input_ids: LongTensor, past_key_values: Cache | None = None, attention_mask: LongTensor | None = None, inputs_embeds: FloatTensor | None = None, cache_position: LongTensor | None = None, **kwargs)

Prepares the model inputs for generation. It includes operations like computing the 4D attention mask or slicing inputs given the existing cache.

See the forward pass in the model documentation for expected arguments (different models might have different requirements for e.g. past_key_values). This function should work as is for most LLMs.
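A minimal usage sketch, assuming a causal LM from the Hub (the model name is only an example, and the exact keys of the returned dictionary depend on the model and the transformers version):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
input_ids = torch.tensor([[1, 2, 3, 4]])
attention_mask = torch.ones_like(input_ids)
cache_position = torch.arange(input_ids.shape[1])

# this is what generate() calls at every decoding step
inputs = model.prepare_inputs_for_generation(
    input_ids, attention_mask=attention_mask, cache_position=cache_position
)
print(sorted(inputs))  # typically contains 'attention_mask', 'cache_position', 'input_ids', ...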

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention(*args, **kwargs)
forward(hidden_states: Tensor, key_value_states: Tensor | None = None, attention_mask: Tensor | None = None, position_ids: LongTensor | None = None, past_key_value: Tuple[Tensor] | None = None, output_attentions: bool = False, use_cache: bool = False, cache_position: LongTensor | None = None, **kwargs) → Tuple[Tensor, Tensor | None, Tuple[Tensor] | None]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
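A tiny self-contained example of that note, unrelated to the attention layer itself: calling the module instance runs registered hooks, while calling forward directly does not:

import torch
from torch import nn

class Doubler(nn.Module):
    def forward(self, x):
        return x * 2

mod = Doubler()
mod.register_forward_hook(lambda module, args, output: print("hook ran"))

mod(torch.ones(2))          # prints "hook ran"
mod.forward(torch.ones(2))  # the hook is silently skipped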

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding(*args, **kwargs)
forward(x, seq_len=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding(*args, **kwargs)
forward(x, position_ids)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__compute_dynamic_ntk_parameters(config: PretrainedConfig | None = None, device: device | None = None, seq_len: int | None = None, **rope_kwargs) → Tuple[Tensor, float]

Manual patch for transformers.modeling_rope_utils._compute_dynamic_ntk_parameters.

Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.

Args:

config (transformers.PretrainedConfig): The model configuration.
device (torch.device): The device to use for initialization of the inverse frequencies.
seq_len (int, optional): The current sequence length, used to update the dynamic RoPE at inference time.
rope_kwargs (Dict, optional): Backward compatibility with the previous RoPE class instantiation, will be removed in v4.45.

Returns:

Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
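A minimal sketch of the dynamic NTK recipe behind this function, assuming a LLaMA-style configuration with rotary dimension dim, base rope_theta and a rope_scaling factor (illustrative only, not the patched function itself):

import torch

def dynamic_ntk_inv_freq_sketch(dim, base, max_position_embeddings, scaling_factor, seq_len):
    # grow the base once the current sequence exceeds the training length
    seq_len = max(seq_len, max_position_embeddings)
    base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
    # standard RoPE inverse frequencies computed on the rescaled base
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

print(dynamic_ntk_inv_freq_sketch(64, 10000.0, 2048, 2.0, 8192).shape)  # torch.Size([32])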

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) → Callable

Manual patch for the function transformers.masking_utils._vmap_for_bhqkv.
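The original transformers helper lifts a scalar mask function mask_function(batch_idx, head_idx, q_idx, kv_idx) to full index tensors by nesting torch.vmap over each index dimension. A hedged sketch of that (un-patched) approach:

import torch

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # scalar rule: a query may attend to itself and to earlier keys
    return kv_idx <= q_idx

def vmap_for_bhqkv_sketch(mask_function, bh_indices=True):
    # vmap once per index dimension, innermost (kv) first
    dims = [(None, None, None, 0), (None, None, 0, None)]
    if bh_indices:
        dims.extend([(None, 0, None, None), (0, None, None, None)])
    for in_dims in dims:
        mask_function = torch.vmap(mask_function, in_dims=in_dims, out_dims=0)
    return mask_function

batched = vmap_for_bhqkv_sketch(causal_mask_function)
b, h, q, kv = torch.arange(1), torch.arange(2), torch.arange(3), torch.arange(4)
print(batched(b, h, q, kv).shape)  # torch.Size([1, 2, 3, 4])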

onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_dynamic_rope_update(rope_forward)

Manual patch for transformers.modeling_rope_utils.dynamic_rope_update.

The rope_type attribute is determined in the constructor of transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding:

if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
    self.rope_type = config.rope_scaling.get(
        "rope_type", config.rope_scaling.get("type"))
else:
    self.rope_type = "default"

The original code of the patched function:

# imports needed by this excerpt
from functools import wraps

import torch


def dynamic_rope_update(rope_forward):
    def longrope_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = (
                self.config.original_max_position_embeddings
            )
        else:
            original_max_position_embeddings = self.config.max_position_embeddings
        if seq_len > original_max_position_embeddings:
            if not hasattr(self, "long_inv_freq"):
                self.long_inv_freq, _ = self.rope_init_fn(
                    self.config, device, seq_len=original_max_position_embeddings + 1
                )
            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
        else:
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

    def dynamic_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        if "dynamic" in self.rope_type:
            dynamic_frequency_update(self, position_ids, device=x.device)
        elif self.rope_type == "longrope":
            longrope_frequency_update(self, position_ids, device=x.device)
        return rope_forward(self, x, position_ids)

    return wrapper
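As a usage sketch (the module below is hypothetical and only shows where the decorator sits): transformers applies dynamic_rope_update to the forward method of its rotary-embedding modules, and the patched version is meant to be used the same way:

import torch
from torch import nn
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_dynamic_rope_update,
)

class ToyRotaryEmbedding(nn.Module):
    rope_type = "default"  # neither "dynamic" nor "longrope": the wrapper just calls forward

    def __init__(self, dim=8, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @patched_dynamic_rope_update
    def forward(self, x, position_ids):
        freqs = position_ids[:, :, None].float() * self.inv_freq
        return freqs.cos(), freqs.sin()

emb = ToyRotaryEmbedding()
cos, sin = emb(torch.zeros(1, 5, 8), torch.arange(5)[None, :])
print(cos.shape)  # torch.Size([1, 5, 4])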