yobx.torch.in_transformers._patches_model_rope_utils

class yobx.torch.in_transformers._patches_model_rope_utils.common_RotaryEmbedding(*args: Any, **kwargs: Any)
forward(x, position_ids, layer_type=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently ignores them.
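A minimal sketch of the difference the note describes, assuming PyTorch is installed. The Doubler module is a hypothetical example, not part of yobx; it only uses the standard nn.Module.register_forward_hook API.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Hypothetical toy module used only to show hook behaviour."""

    def forward(self, x):
        return 2 * x


calls = []
m = Doubler()
# The hook receives (module, inputs, output) after forward has run.
m.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.ones(3)
y1 = m(x)          # __call__: runs forward, then the registered hook
y2 = m.forward(x)  # direct call: same result, but the hook is skipped
```

After both calls, calls contains a single "hook" entry, even though the two outputs are identical.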

yobx.torch.in_transformers._patches_model_rope_utils.patched__compute_dynamic_ntk_parameters(config: PreTrainedConfig | None = None, device: device | None = None, seq_len: int | None = None, **rope_kwargs) → Tuple[Tensor, float]

manual patch: [patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]

Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.
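As far as we can tell from the upstream transformers implementation of _compute_dynamic_ntk_parameters, the NTK scaling enlarges the RoPE base as the sequence grows past the trained context. With θ the base (config.rope_theta), s the scaling factor (config.rope_scaling["factor"]), L_max = config.max_position_embeddings, d the rotary dimension and L the sequence length clamped from below at L_max, the computation is, in sketch form:

```latex
L = \max(\mathrm{seq\_len},\, L_{\max}), \qquad
\theta' = \theta \left( \frac{s\,L}{L_{\max}} - (s - 1) \right)^{d/(d-2)}, \qquad
\omega_i = (\theta')^{-2i/d}, \quad i = 0, \dots, \tfrac{d}{2} - 1
```

At L = L_max the bracket equals 1, so θ' = θ and the frequencies are the unscaled ones.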

Parameters:
  • config (transformers.PretrainedConfig) – The model configuration.

  • device (torch.device) – The device to use for initialization of the inverse frequencies.

  • seq_len (int, optional) – The current sequence length, used to update the dynamic RoPE at inference time.

  • rope_kwargs (Dict, optional) – Backward compatibility with the previous RoPE class instantiation; will be removed in v4.45.

Returns:

A tuple (torch.Tensor, float) containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
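A pure-Python sketch of the same computation, assuming the formula used by upstream transformers. ntk_inv_freq is a hypothetical helper, not the patched function: it takes plain numbers instead of a config and a device, and returns a list instead of a tensor.

```python
def ntk_inv_freq(base, dim, max_position_embeddings, factor, seq_len=None):
    """Hypothetical sketch of dynamic NTK inverse frequencies.

    base    -- RoPE base (config.rope_theta upstream)
    dim     -- rotary dimension, assumed even and > 2
    factor  -- scaling factor (config.rope_scaling["factor"] upstream)
    """
    # Below the trained context length the base is left unchanged.
    if seq_len is None or seq_len <= max_position_embeddings:
        seq_len = max_position_embeddings
    # NTK-aware rescaling: the base grows with the sequence length.
    ratio = (factor * seq_len / max_position_embeddings) - (factor - 1)
    scaled_base = base * ratio ** (dim / (dim - 2))
    # omega_i = scaled_base ** (-2i / dim) for i = 0 .. dim/2 - 1
    inv_freq = [1.0 / scaled_base ** (i / dim) for i in range(0, dim, 2)]
    attention_factor = 1.0  # unused for this RoPE type
    return inv_freq, attention_factor


short, _ = ntk_inv_freq(10000.0, 4, 2048, 2.0, seq_len=1000)
long_, scale = ntk_inv_freq(10000.0, 4, 2048, 2.0, seq_len=4096)
```

With factor 2 and a 2048-token trained context, a 4096-token sequence gives (2 * 4096 / 2048 - 1) ** 2 = 9, so the base grows from 10000 to 90000 and the lowest frequency drops from 1/100 to 1/300.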

yobx.torch.in_transformers._patches_model_rope_utils.patched_dynamic_rope_update(rope_forward)

manual patch: [patch:transformers.modeling_rope_utils.dynamic_rope_update]

self.rope_type is set in the constructor of transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding as follows:

if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
    self.rope_type = config.rope_scaling.get(
        "rope_type", config.rope_scaling.get("type"))
else:
    self.rope_type = "default"
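The selection above can be tried in isolation. resolve_rope_type is a hypothetical helper that reproduces the quoted logic on a SimpleNamespace standing in for the model config; reading "type" as a fallback presumably supports the legacy key name.

```python
from types import SimpleNamespace


def resolve_rope_type(config):
    """Hypothetical helper reproducing the rope_type selection quoted above."""
    if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
        # "rope_type" is read first; "type" is the fallback key.
        return config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
    return "default"


examples = {
    "longrope": resolve_rope_type(SimpleNamespace(rope_scaling={"rope_type": "longrope"})),
    "legacy": resolve_rope_type(SimpleNamespace(rope_scaling={"type": "dynamic"})),
    "none": resolve_rope_type(SimpleNamespace(rope_scaling=None)),
    "missing": resolve_rope_type(SimpleNamespace()),
}
```

The wrapper dispatches on this value: any rope_type containing "dynamic" triggers the dynamic frequency update, "longrope" triggers the longrope update, and every other value calls rope_forward unchanged.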

The original code of the patched function:

import torch
from functools import wraps

def dynamic_rope_update(rope_forward):
    def longrope_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
        else:
            original_max_position_embeddings = self.config.max_position_embeddings
        if seq_len > original_max_position_embeddings:
            if not hasattr(self, "long_inv_freq"):
                self.long_inv_freq, _ = self.rope_init_fn(
                    self.config, device, seq_len=original_max_position_embeddings + 1
                )
            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
        else:
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

    def dynamic_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if (seq_len < self.original_max_seq_len
                and self.max_seq_len_cached > self.original_max_seq_len):  # reset
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        if "dynamic" in self.rope_type:
            dynamic_frequency_update(self, position_ids, device=x.device)
        elif self.rope_type == "longrope":
            longrope_frequency_update(self, position_ids, device=x.device)
        return rope_forward(self, x, position_ids)

    return wrapper
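The cache bookkeeping of dynamic_frequency_update is easier to follow without tensors. ToyDynamicRope is a hypothetical stand-in that keeps only max_seq_len_cached and a label for the current frequencies; it mirrors the grow and reset branches above but computes nothing.

```python
class ToyDynamicRope:
    """Hypothetical stand-in tracking only the bookkeeping of
    dynamic_frequency_update; frequencies are labelled, not computed."""

    def __init__(self, original_max_seq_len):
        self.original_max_seq_len = original_max_seq_len
        self.max_seq_len_cached = original_max_seq_len
        self.inv_freq = "original"

    def update(self, seq_len):
        if seq_len > self.max_seq_len_cached:  # growth: recompute for the longer length
            self.inv_freq = f"ntk@{seq_len}"
            self.max_seq_len_cached = seq_len
        if (seq_len < self.original_max_seq_len
                and self.max_seq_len_cached > self.original_max_seq_len):  # reset
            self.inv_freq = "original"
            self.max_seq_len_cached = self.original_max_seq_len


rope = ToyDynamicRope(original_max_seq_len=2048)
trace = []
for n in (1000, 3000, 2500, 4000, 100):
    rope.update(n)
    trace.append((n, rope.inv_freq, rope.max_seq_len_cached))
```

Note the third step: a 2500-token sequence keeps the frequencies computed for 3000 tokens, because the code only recomputes on growth and only resets once the length falls below original_max_seq_len. Because these branches test the value of torch.max(position_ids), they are data-dependent control flow, which tracers such as torch.export generally cannot capture as written; that is presumably what the yobx patch addresses.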