Patches List

torch

<<<

import textwrap
from yobx.torch import apply_patches_for_model
from yobx.torch.in_torch.patches import get_patches

with apply_patches_for_model(patch_torch=True):
    link = []
    rows = []
    for i, patch in enumerate(get_patches()):
        name = f"patch-torch-{i+1}"
        link.append(f"* :ref:`{name}`")
        rows.extend(
            [
                "",
                f".. _{name}:",
                "",
                patch.patch.__qualname__,
                "-" * len(patch.patch.__qualname__),
                "",
                "::",
                "",
                textwrap.indent(patch.format_diff(), "    "),
                "",
            ]
        )
    print("\n".join([*link, "", *rows]))

>>>

[runpythonerror] Traceback (most recent call last):
  File "<stdin>", line 27, in <module>
  File "~/github/yet-another-onnx-builder/yobx/helpers/patch_helper.py", line 225, in format_diff
    diff = self.make_diff()
  File "~/github/yet-another-onnx-builder/yobx/helpers/patch_helper.py", line 187, in make_diff
    assert f is not None, (
AssertionError: The patch '_print_Symbol' was never applied self.function_to_patch=None, self._last_patched_function=None.
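
The assertion comes from format_diff, which requires a patch that was actually
applied, while get_patches() appears to enumerate patch definitions only. A
minimal variant that sticks to applied patches, assuming
apply_patches_for_model(patch_torch=True) yields them the same way the
transformers block below does:

    import textwrap
    from yobx.torch import apply_patches_for_model

    # Hedged sketch: assumes the context manager yields the applied patch
    # objects, as in the transformers example below; format_diff asserts the
    # patch was applied, which get_patches() does not guarantee.
    with apply_patches_for_model(patch_torch=True) as details:
        for patch in details:
            print(patch.patch.__qualname__)
            print(textwrap.indent(patch.format_diff(), "    "))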

transformers

<<<

import textwrap
from yobx.torch import apply_patches_for_model

with apply_patches_for_model(patch_transformers=True) as details:
    link = []
    rows = []
    for i, patch in enumerate(details):
        name = f"patch-transformers-{i+1}"
        link.append(f"* :ref:`{name}`")
        rows.extend(
            [
                "",
                f".. _{name}:",
                "",
                patch.patch.__qualname__,
                "-" * len(patch.patch.__qualname__),
                "",
                "::",
                "",
                textwrap.indent(patch.format_diff(), "    "),
                "",
            ]
        )
    print("\n".join([*link, "", *rows]))

>>>

common_RotaryEmbedding.forward

transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward
--- original
+++ rewritten
@@ -1,18 +1,26 @@
-@torch.no_grad()
-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-def forward(self, x, position_ids):
+@patched_dynamic_rope_update
+def forward(self, x, position_ids, layer_type=None):
+    if layer_type is not None:
+        # transformers>=5.0
+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+    else:
+        # transformers<5.0
+        inv_freq = self.inv_freq
+        attention_scaling = self.attention_scaling
+
     inv_freq_expanded = (
-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
     )
     position_ids_expanded = position_ids[:, None, :].float()

     device_type = (
         x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
     )
-    with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos() * self.attention_scaling
-        sin = emb.sin() * self.attention_scaling
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling

     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
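
The rewritten forward handles both attribute layouts with a single getattr
dispatch. A tiny standalone illustration of that pattern (the class and the
sliding_attention layer type below are illustrative, not the transformers
API); the second diff below shows the matching rewrite of the decorator
itself:

    import torch

    class Rope(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # transformers<5.0 layout: flat attributes
            self.inv_freq = torch.rand(8)
            self.attention_scaling = 1.0
            # transformers>=5.0 layout: one set of attributes per layer type
            self.sliding_attention_inv_freq = torch.rand(8)
            self.sliding_attention_attention_scaling = 1.0

    def select_rope_attributes(mod, layer_type=None):
        # Same dispatch as the patch: prefixed attributes when layer_type is set.
        if layer_type is not None:
            return (
                getattr(mod, f"{layer_type}_inv_freq"),
                getattr(mod, f"{layer_type}_attention_scaling"),
            )
        return mod.inv_freq, mod.attention_scaling

    rope = Rope()
    assert select_rope_attributes(rope)[0].shape == (8,)
    assert select_rope_attributes(rope, "sliding_attention")[1] == 1.0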

--- original
+++ rewritten
@@ -1,103 +1,193 @@
-def dynamic_rope_update(rope_forward):
-    """
-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

-    Args:
-        rope_forward (Callable):
-            The forward pass of the RoPE implementation.
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.

-    Returns:
-        The decorated forward pass.
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings = self.config.original_max_position_embeddings
+                else:
+                    original_max_position_embeddings = self.config.max_position_embeddings
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if (seq_len < self.original_max_seq_len
+                        and self.max_seq_len_cached > self.original_max_seq_len):
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """

     def longrope_frequency_update(self, position_ids, device, layer_type=None):
-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        # Patching this function after the model is created has no effect:
+        # rope_init_fn is an attribute bound to a concrete function at model
+        # creation time, before any patch is applied.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings

         if layer_type is None:
-            rope_type = self.rope_type
-            original_inv_freq = self.original_inv_freq
-            prefix = ""
-            original_max_position_embeddings = self.config.rope_parameters[
-                "original_max_position_embeddings"
-            ]
-        else:
-            rope_type = self.rope_type[layer_type]
-            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
-            prefix = f"{layer_type}_"
-            original_max_position_embeddings = self.config.rope_parameters[layer_type][
-                "original_max_position_embeddings"
-            ]
-
-        if seq_len > original_max_position_embeddings:
-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-                long_inv_freq, _ = rope_init_fn(
-                    self.config,
-                    device,
-                    seq_len=original_max_position_embeddings + 1,
-                    layer_type=layer_type,
-                )
-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
-        else:
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-
-    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if layer_type is None:
-            rope_type = self.rope_type
-            max_seq_len_cached = self.max_seq_len_cached
+            # rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
         else:
-            rope_type = self.rope_type[layer_type]
-            max_seq_len_cached = getattr(
-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
-            )
+            # rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"

-        if seq_len > max_seq_len_cached:  # growth
-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-            inv_freq, self.attention_scaling = rope_init_fn(
-                self.config,
-                device,
-                seq_len=seq_len,
-                layer_type=layer_type,
-            )
-            # TODO joao: may break with compilation
-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)

-        if (
-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
-        ):  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
+        # PATCHED: uses torch.cond instead of a test
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)
+        # if seq_len > original_max_position_embeddings:
+        #    self.inv_freq = self.long_inv_freq
+        # else:
+        #    self.inv_freq = self.original_inv_freq
+
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # Patching this function after the model is created has no effect:
+        # rope_init_fn is an attribute bound to a concrete function at model
+        # creation time, before any patch is applied.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
+
+        # This behaviour is difficult to translate: the sequence always
+        # grows, so the test below should always be True.
+        # Hence self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
+        #        self.config, device, seq_len=seq_len
+        #    )
+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #    self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #    cond,
+        #    (lambda x, y: x.clone()),
+        #    (lambda x, y: y.clone()),
+        #    [attention_scaling, self.attention_scaling],
+        # )
+
+        seq_len = torch.max(position_ids) + 1
+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
+        # Second test to translate.
+        # Keep in mind that self.max_seq_len_cached == seq_len is likely to
+        # hold; in that case the following condition restores the original cache.
+
+        # if (
+        #    seq_len < self.original_max_seq_len
+        #    and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #    self.original_inv_freq = self.original_inv_freq.to(device)
+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #    self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)

     @wraps(rope_forward)
     def wrapper(self, x, position_ids, layer_type=None):
-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
-        if "dynamic" in rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
-        elif rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
-        return rope_forward(self, x, position_ids, **kwargs)
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        return rope_forward(self, x, position_ids, layer_type=layer_type)

     return wrapper
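
Both rewritten update functions replace a data-dependent Python test with
torch.cond so torch.export can keep both branches in the graph instead of
freezing the branch taken with the sample inputs. A minimal, self-contained
sketch of that pattern (tensor names and the threshold are illustrative, not
the library's API):

    import torch

    def pick_inv_freq(seq_len, long_inv_freq, original_inv_freq, threshold):
        # torch.cond traces both branches; they take the same operands and
        # must return tensors with identical shape and dtype. The .clone()
        # mirrors the patch above, whose branch outputs do not alias inputs.
        return torch.cond(
            seq_len > threshold,                    # 0-dim boolean predicate
            lambda long_f, orig_f: long_f.clone(),  # sequence beyond threshold
            lambda long_f, orig_f: orig_f.clone(),  # original frequencies
            [long_inv_freq, original_inv_freq],
        )

    long_f = torch.rand(8)
    orig_f = torch.rand(8)
    inv_freq = pick_inv_freq(torch.tensor(4096), long_f, orig_f, threshold=2048)
    assert inv_freq.shape == (8,)

In eager mode torch.cond simply evaluates the predicate and runs one branch;
under torch.export the predicate stays symbolic and both branches are captured.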