Patches List#
torch#
<<<
import textwrap

from yobx.torch import apply_patches_for_model
from yobx.torch.in_torch.patches import get_patches

with apply_patches_for_model(patch_torch=True):
    link = []
    rows = []
    for i, patch in enumerate(get_patches()):
        name = f"patch-torch-{i+1}"
        link.append(f"* :ref:`{name}`")
        rows.extend(
            [
                "",
                f".. _{name}:",
                "",
                patch.patch.__qualname__,
                "-" * len(patch.patch.__qualname__),
                "",
                "::",
                "",
                textwrap.indent(patch.format_diff(), "    "),
                "",
            ]
        )
    print("\n".join([*link, "", *rows]))
>>>
[runpythonerror] Traceback (most recent call last):
  File "<stdin>", line 27, in <module>
  File "~/github/yet-another-onnx-builder/yobx/helpers/patch_helper.py", line 225, in format_diff
    diff = self.make_diff()
  File "~/github/yet-another-onnx-builder/yobx/helpers/patch_helper.py", line 187, in make_diff
    assert f is not None, (
AssertionError: The patch '_print_Symbol' was never applied self.function_to_patch=None, self._last_patched_function=None.
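The assertion comes from ``format_diff``: a diff can only be produced for a patch that was actually applied, and ``_print_Symbol`` still has ``self.function_to_patch=None``. A minimal sketch of a guarded listing, using only the attributes visible on this page and skipping never-applied patches instead of failing::

    import textwrap
    from yobx.torch import apply_patches_for_model
    from yobx.torch.in_torch.patches import get_patches

    with apply_patches_for_model(patch_torch=True):
        for i, patch in enumerate(get_patches()):
            try:
                diff = patch.format_diff()
            except AssertionError:
                # never-applied patch (function_to_patch is None): skip it
                continue
            print(f"patch-torch-{i+1}: {patch.patch.__qualname__}")
            print(textwrap.indent(diff, "    "))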
transformers#
<<<
import textwrap

from yobx.torch import apply_patches_for_model

with apply_patches_for_model(patch_transformers=True) as details:
    link = []
    rows = []
    for i, patch in enumerate(details):
        name = f"patch-transformers-{i+1}"
        link.append(f"* :ref:`{name}`")
        rows.extend(
            [
                "",
                f".. _{name}:",
                "",
                patch.patch.__qualname__,
                "-" * len(patch.patch.__qualname__),
                "",
                "::",
                "",
                textwrap.indent(patch.format_diff(), "    "),
                "",
            ]
        )
    print("\n".join([*link, "", *rows]))
>>>
common_RotaryEmbedding.forward#
transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward
--- original
+++ rewritten
@@ -1,18 +1,26 @@
-@torch.no_grad()
-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-def forward(self, x, position_ids):
+@patched_dynamic_rope_update
+def forward(self, x, position_ids, layer_type=None):
+    if layer_type is not None:
+        # transformers>=5.0
+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+    else:
+        # transformers<5.0
+        inv_freq = self.inv_freq
+        attention_scaling = self.attention_scaling
+
     inv_freq_expanded = (
-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
     )
     position_ids_expanded = position_ids[:, None, :].float()
     device_type = (
         x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
     )
-    with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos() * self.attention_scaling
-        sin = emb.sin() * self.attention_scaling
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
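The rewritten ``forward`` dispatches on ``layer_type`` because transformers>=5.0 stores one rotary buffer per layer type while older versions keep a single ``self.inv_freq``. A self-contained sketch of that dispatch pattern (the ``TinyRope`` class and the ``sliding_attention`` name are illustrative, not part of the patch)::

    import torch
    import torch.nn as nn

    class TinyRope(nn.Module):
        def __init__(self):
            super().__init__()
            # transformers<5.0 layout: one shared buffer
            self.register_buffer("inv_freq", torch.ones(4))
            # transformers>=5.0 layout: one buffer per layer type
            self.register_buffer("sliding_attention_inv_freq", torch.ones(4) * 2)

        def pick(self, layer_type=None):
            if layer_type is not None:
                return getattr(self, f"{layer_type}_inv_freq")
            return self.inv_freq

    rope = TinyRope()
    assert rope.pick().sum().item() == 4.0
    assert rope.pick("sliding_attention").sum().item() == 8.0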
--- original
+++ rewritten
@@ -1,103 +1,193 @@
-def dynamic_rope_update(rope_forward):
-    """
-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
-    Args:
-        rope_forward (Callable):
-            The forward pass of the RoPE implementation.
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
-    Returns:
-        The decorated forward pass.
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings = (
+                        self.config.original_max_position_embeddings
+                    )
+                else:
+                    original_max_position_embeddings = (
+                        self.config.max_position_embeddings
+                    )
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if (seq_len < self.original_max_seq_len and
+                        self.max_seq_len_cached > self.original_max_seq_len):
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """
     def longrope_frequency_update(self, position_ids, device, layer_type=None):
-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
         if layer_type is None:
-            rope_type = self.rope_type
-            original_inv_freq = self.original_inv_freq
-            prefix = ""
-            original_max_position_embeddings = self.config.rope_parameters[
-                "original_max_position_embeddings"
-            ]
-        else:
-            rope_type = self.rope_type[layer_type]
-            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
-            prefix = f"{layer_type}_"
-            original_max_position_embeddings = self.config.rope_parameters[layer_type][
-                "original_max_position_embeddings"
-            ]
-
-        if seq_len > original_max_position_embeddings:
-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-                long_inv_freq, _ = rope_init_fn(
-                    self.config,
-                    device,
-                    seq_len=original_max_position_embeddings + 1,
-                    layer_type=layer_type,
-                )
-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
-        else:
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-
-    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if layer_type is None:
-            rope_type = self.rope_type
-            max_seq_len_cached = self.max_seq_len_cached
+            # rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
         else:
-            rope_type = self.rope_type[layer_type]
-            max_seq_len_cached = getattr(
-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
-            )
+            # rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"
-        if seq_len > max_seq_len_cached:  # growth
-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-            inv_freq, self.attention_scaling = rope_init_fn(
-                self.config,
-                device,
-                seq_len=seq_len,
-                layer_type=layer_type,
-            )
-            # TODO joao: may break with compilation
-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)
-        if (
-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
-        ):  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
+        # PATCHED: uses torch.cond instead of a test
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)
+        # if seq_len > original_max_position_embeddings:
+        #     self.inv_freq = self.long_inv_freq
+        # else:
+        #     self.inv_freq = self.original_inv_freq
+
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows.
+        # The test should always be True.
+        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #     inv_freq, self.attention_scaling = self.rope_init_fn(
+        #         self.config, device, seq_len=seq_len
+        #     )
+        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #     self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #     cond,
+        #     (lambda x, y: x.clone()),
+        #     (lambda x, y: y.clone()),
+        #     [attention_scaling, self.attention_scaling],
+        # )
+
+        seq_len = torch.max(position_ids) + 1
+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
+        # Second test to translate.
+        # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
+        # But in that case the following condition is a way to restore the original cache.
+
+        # if (
+        #     seq_len < self.original_max_seq_len
+        #     and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #     self.original_inv_freq = self.original_inv_freq.to(device)
+        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #     self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)
     @wraps(rope_forward)
     def wrapper(self, x, position_ids, layer_type=None):
-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
-        if "dynamic" in rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
-        elif rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
-        return rope_forward(self, x, position_ids, **kwargs)
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        return rope_forward(self, x, position_ids, layer_type=layer_type)
     return wrapper
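Both patched helpers replace a data-dependent Python ``if`` with ``torch.cond`` so that ``torch.export`` can trace both branches instead of committing to one at trace time. A minimal standalone sketch of the same pattern (function and variable names are illustrative)::

    import torch

    def pick_inv_freq(seq_len, long_inv_freq, original_inv_freq):
        # data-dependent predicate as a 0-dim bool tensor: the choice
        # stays inside the traced graph instead of being a Python branch
        cond = seq_len > 10
        return torch.cond(
            cond,
            lambda x, y: x.clone(),
            lambda x, y: y.clone(),
            [long_inv_freq, original_inv_freq],
        )

    long_f = torch.full((4,), 2.0)
    orig_f = torch.ones(4)
    print(pick_inv_freq(torch.tensor(12), long_f, orig_f))  # long branch
    print(pick_inv_freq(torch.tensor(3), long_f, orig_f))   # original branch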