yobx.torch.in_transformers._patches_model_rope_utils#
- class yobx.torch.in_transformers._patches_model_rope_utils.common_RotaryEmbedding(*args: Any, **kwargs: Any)[source]#
- forward(x, position_ids, layer_type=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
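The note above is the standard torch.nn.Module contract: __call__ dispatches registered hooks around forward(), while calling forward() directly skips them. A minimal, framework-free sketch of that dispatch (the Module class below is a hypothetical stand-in, not the torch implementation):

```python
class Module:
    # Hypothetical stand-in for torch.nn.Module's hook dispatch (illustration only).
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        # Calling the instance runs forward() AND the registered hooks.
        out = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, out)
        return out


class Doubler(Module):
    def forward(self, x):
        return 2 * x


calls = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))
m(3)          # hooks run: calls == [6]
m.forward(3)  # hooks silently skipped: calls unchanged
```

This is why the patched common_RotaryEmbedding, like any Module subclass, should be invoked as an instance rather than via .forward().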
- yobx.torch.in_transformers._patches_model_rope_utils.patched__compute_dynamic_ntk_parameters(config: PreTrainedConfig | None = None, device: device | None = None, seq_len: int | None = None, **rope_kwargs) Tuple[Tensor, float][source]#
manual patch: [patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]
Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.
- Parameters:
config (transformers.PretrainedConfig) – The model configuration.
device (torch.device) – The device to use for initialization of the inverse frequencies.
seq_len (int, optional) – The current sequence length, used to update the dynamic RoPE at inference time.
rope_kwargs (Dict, optional) – Backward compatibility with the previous RoPE class instantiation; will be removed in v4.45.
- Returns:
Tuple of (torch.Tensor, float), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
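The dynamic NTK formula this patch targets can be condensed to a framework-free sketch (pure Python; dim, base, factor, and max_pos stand in for values the real function reads from the model config, and the real function returns a torch.Tensor rather than a list):

```python
def dynamic_ntk_inv_freq(dim, base, max_pos, factor, seq_len):
    # Dynamic NTK scaling: once seq_len exceeds the trained context window,
    # grow the rotary base so the longest wavelength covers the new length.
    if seq_len > max_pos:
        base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    # Standard RoPE inverse frequencies over the even dimensions.
    inv_freq = [1.0 / base ** (i / dim) for i in range(0, dim, 2)]
    attention_scaling = 1.0  # unused in this type of RoPE
    return inv_freq, attention_scaling
```

Scaling lowers every frequency except the first (whose exponent is zero), stretching the rotary wavelengths to span the longer sequence.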
- yobx.torch.in_transformers._patches_model_rope_utils.patched_dynamic_rope_update(rope_forward)[source]#
manual patch: [patch:transformers.modeling_rope_utils.dynamic_rope_update]
rope_type is determined in the constructor of class transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding:

    if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
        self.rope_type = config.rope_scaling.get(
            "rope_type", config.rope_scaling.get("type"))
    else:
        self.rope_type = "default"
The original code of the patched function:
    def dynamic_rope_update(rope_forward):
        def longrope_frequency_update(self, position_ids, device):
            seq_len = torch.max(position_ids) + 1
            if hasattr(self.config, "original_max_position_embeddings"):
                original_max_position_embeddings = self.config.original_max_position_embeddings
            else:
                original_max_position_embeddings = self.config.max_position_embeddings
            if seq_len > original_max_position_embeddings:
                if not hasattr(self, "long_inv_freq"):
                    self.long_inv_freq, _ = self.rope_init_fn(
                        self.config, device, seq_len=original_max_position_embeddings + 1
                    )
                self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
            else:
                self.original_inv_freq = self.original_inv_freq.to(device)
                self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

        def dynamic_frequency_update(self, position_ids, device):
            seq_len = torch.max(position_ids) + 1
            if seq_len > self.max_seq_len_cached:  # growth
                inv_freq, self.attention_scaling = self.rope_init_fn(
                    self.config, device, seq_len=seq_len)
                self.register_buffer("inv_freq", inv_freq, persistent=False)
                self.max_seq_len_cached = seq_len
            if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
                self.original_inv_freq = self.original_inv_freq.to(device)
                self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
                self.max_seq_len_cached = self.original_max_seq_len

        @wraps(rope_forward)
        def wrapper(self, x, position_ids):
            if "dynamic" in self.rope_type:
                dynamic_frequency_update(self, position_ids, device=x.device)
            elif self.rope_type == "longrope":
                longrope_frequency_update(self, position_ids, device=x.device)
            return rope_forward(self, x, position_ids)

        return wrapper
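The decorator's core behavior, growing the cached frequency table whenever a longer sequence arrives, can be condensed to a framework-free sketch (DummyRope and its rope_init_fn are hypothetical stand-ins; buffer registration and the longrope/shrink branches are omitted):

```python
from functools import wraps


def dynamic_rope_update_sketch(rope_forward):
    # Simplified version of the wrapper: before each forward call,
    # recompute the inverse frequencies if the sequence got longer.
    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        seq_len = max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            self.inv_freq = self.rope_init_fn(seq_len)
            self.max_seq_len_cached = seq_len
        return rope_forward(self, x, position_ids)
    return wrapper


class DummyRope:
    # Hypothetical stand-in for a RotaryEmbedding module.
    def __init__(self):
        self.init_calls = 0
        self.max_seq_len_cached = 4
        self.inv_freq = self.rope_init_fn(4)

    def rope_init_fn(self, seq_len):
        self.init_calls += 1
        return [1.0 / seq_len] * 2  # placeholder frequencies

    @dynamic_rope_update_sketch
    def forward(self, x, position_ids):
        return self.inv_freq


rope = DummyRope()
rope.forward(None, [0, 1, 2])        # seq_len=3 <= 4: no recompute
rope.forward(None, [0, 1, 2, 3, 4])  # seq_len=5 > 4: frequencies recomputed
```

The patched version keeps this shape but reads rope_type with the rope_scaling fallback shown in the patch note, so configs that only define "type" still select the dynamic or longrope branch correctly.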