Patches Diff

Patches are not always needed to export an LLM. Most of the time, only serialization functions are needed to export an LLM with a cache (DynamicCache, …). The function register_additional_serialization_functions is enough in many cases.

import torch
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

with register_additional_serialization_functions(patch_transformers=True):
    ep = torch.export.export(...)
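
The call above is only a skeleton. The snippet below is a minimal, self-contained sketch, not taken from the library's documentation: a toy attention module whose forward receives a transformers DynamicCache, which torch.export can only flatten once the serialization functions are registered. The module name ToyAttention and the tensor shapes are made up for the illustration.

import torch
from transformers import DynamicCache
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions


class ToyAttention(torch.nn.Module):
    def forward(self, query, key, value, cache):
        # append the new key/value states to layer 0 of the cache,
        # then attend over everything stored so far
        keys, values = cache.update(key, value, 0)
        return torch.nn.functional.scaled_dot_product_attention(query, keys, values)


cache = DynamicCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), 0)
inputs = (torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), cache)

with register_additional_serialization_functions(patch_transformers=True):
    ep = torch.export.export(ToyAttention(), inputs)

Without the context manager, torch.export.export does not know how to flatten the cache into tensors; with it, DynamicCache is handled like any other pytree input.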

The function torch_export_patches helps fix additional issues for many models.

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(...)
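
The sketch below shows the kind of call these patches are meant to make work end to end. It is an illustration under stated assumptions, not the documentation's reference example: a tiny randomly initialized LlamaForCausalLM built from a LlamaConfig, with Dim.DYNAMIC markers on the batch and sequence dimensions.

import torch
from transformers import LlamaConfig, LlamaForCausalLM
from onnx_diagnostic.torch_export_patches import torch_export_patches

# a tiny randomly initialized model, nothing is downloaded
config = LlamaConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = LlamaForCausalLM(config).eval()

input_ids = torch.randint(0, 1024, (2, 16))
attention_mask = torch.ones_like(input_ids)
DYN = torch.export.Dim.DYNAMIC
dynamic_shapes = {
    "input_ids": {0: DYN, 1: DYN},
    "attention_mask": {0: DYN, 1: DYN},
}

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model,
        (),
        kwargs=dict(input_ids=input_ids, attention_mask=attention_mask),
        dynamic_shapes=dynamic_shapes,
    )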

The class PatchDetails shows how to retrieve the list of patches involved for a specific model. Those patches belong to the following list, which depends on the transformers and pytorch versions.

<<<

import torch
import transformers

print(torch.__version__, transformers.__version__)

>>>

    2.10.0.dev20251022+cu130 5.0.0.dev0

These two versions lead to the following list of patches.

<<<

from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_export_patches import torch_export_patches

details = PatchDetails()
with torch_export_patches(
    patch_transformers=True,
    patch_torch=True,
    patch_diffusers=True,
    patch_details=details,
):
    pass
for patch in details.patched:
    if patch.function_to_patch == patch.patch:
        continue
    rst = patch.format_diff(format="rst")
    print()
    print()
    print(rst)
    print()
    print()

>>>

sympy: ‘sympy.core.numbers.IntegerConstant.name’ -> _patch_sympy.<locals>.<lambda>

1sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"

torch: infer_size -> patched_infer_size

 1--- original
 2+++ rewritten
 3@@ -1,4 +1,5 @@
 4-def infer_size(a, b):
 5+def patched_infer_size(a, b):
 6+    """Patches ``torch._subclasses.fake_impls.infer_size``."""
 7     from torch.fx.experimental.symbolic_shapes import guard_or_false
 8
 9     dimsA = len(a)
10@@ -23,11 +24,21 @@
11         # expression of an or statement as-is, without bool()'ing it; if this
12         # were not the case, we'd need to write this using torch.sym_or() or
13         # something like that).
14-        torch._check(
15-            guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
16-            lambda: f"The size of tensor a ({sizeA}) "
17-            f"must match the size of tensor b ({sizeB}) "
18-            f"at non-singleton dimension {i})",
19-        )
20-        expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
21+        try:
22+            b1 = guard_or_false(sizeA == 1)
23+        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
24+            b1 = False
25+        try:
26+            b2 = guard_or_false(sizeB == 1)
27+        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
28+            b2 = False
29+        try:
30+            b3 = guard_or_false(sizeA == sizeB)
31+        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
32+            b3 = False
33+        if b1 or b2 or b3:
34+            expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
35+        else:
36+            # PATCHED: generic case, the dimension is known, no need to assert
37+            expandedSizes[i] = torch.sym_max(sizeA, sizeB)
38     return tuple(expandedSizes)

torch: _broadcast_shapes -> patched__broadcast_shapes

 1--- original
 2+++ rewritten
 3@@ -1,5 +1,11 @@
 4-def _broadcast_shapes(*_shapes):
 5-    from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int
 6+def patched__broadcast_shapes(*_shapes):
 7+    """Patches ``torch._refs._broadcast_shapes``."""
 8+    from functools import reduce
 9+    from torch._prims_common import IntLike
10+    from torch.fx.experimental.symbolic_shapes import (
11+        guard_or_false,
12+        is_nested_int,
13+    )
14
15     shapes = tuple(
16         (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
17@@ -12,17 +18,15 @@
18     for shape in shapes:
19         if not isinstance(shape, Sequence):
20             raise RuntimeError(
21-                "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
22+                "Input shapes should be of type ints, a tuple of ints, "
23+                "or a list of ints, got ",
24                 shape,
25             )
26
27     # Computes common shape
28-    common_shape: list[Union[int, torch.SymInt]] = [
29-        1,
30-    ] * reduce(max, (len(shape) for shape in shapes))
31-    for arg_idx, shape in enumerate(shapes):
32+    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
33+    for _arg_idx, shape in enumerate(shapes):
34         for idx in range(-1, -1 - len(shape), -1):
35-            # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
36             if is_nested_int(shape[idx]):
37                 # Broadcasting is allowed for (j0, 1) or (j0, j0);
38                 # not (j0, j1), (j0, 5), etc.
39@@ -33,22 +37,15 @@
40             else:
41                 if guard_or_false(shape[idx] == common_shape[idx]):
42                     continue
43-
44-            if guard_or_false(common_shape[idx] == 1):
45+            # PATCHED: two cases, if == for sure, no broadcast,
46+            # otherwise maybe broadcast with max(dimensions)
47+            if guard_or_false(common_shape[idx] != 1):
48+                pass
49+            elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
50                 if shape[idx] < 0:
51                     raise ValueError("Attempting to broadcast a dimension with negative length!")
52                 common_shape[idx] = shape[idx]
53-
54-            if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
55-                # broadcast case .
56-                continue
57             else:
58-                # If broadcasting is undecided we pick non-broadcast path and add runtime assertion.
59-                torch._check(
60-                    common_shape[idx] == shape[idx],
61-                    lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
62-                    f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
63-                    f"should be broadcastable to {common_shape}",
64-                )
65+                common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
66
67     return common_shape

torch: _constrain_user_specified_dimhint_range -> patched__constrain_user_specified_dimhint_range

 1--- original
 2+++ rewritten
 3@@ -1,28 +1,31 @@
 4-def _constrain_user_specified_dimhint_range(
 5+def patched__constrain_user_specified_dimhint_range(
 6     symint: torch.SymInt,
 7     hint: int,
 8-    dim: _DimHint,
 9+    dim: "_DimHint",  # noqa: F821
10     range_constraints,
11     shape_env,
12-    keypath: KeyPath,
13+    keypath: "KeyPath",  # noqa: F821
14     i: Optional[int] = None,
15 ) -> Optional[str]:
16+    """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
17+    from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
18+
19     trace_vr = (
20         range_constraints[symint.node.expr]
21         if not is_int(symint)
22         else ValueRanges(int(symint), int(symint))
23     )
24-
25     # warn on 0/1 specialization for Dim.AUTO; not an actual error
26-    if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
27-        pathstr = f"inputs{pytree.keystr(keypath)}"
28-        if i is not None:
29-            pathstr += f".shape[{i}]"
30-        msg = (
31-            f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
32-            + f"with a sample input with hint = {hint}."
33-        )
34-        log.warning(msg)
35+    # PATCHED: remove logging
36+    # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
37+    #    pathstr = f"inputs{pytree.keystr(keypath)}"
38+    #    if i is not None:
39+    #        pathstr += f".shape[{i}]"
40+    #    msg = (
41+    #        f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
42+    #        f"with a sample input with hint = {hint}."
43+    #    )
44+    #    log.warning(msg)
45
46     try:
47         user_vr = ValueRanges(
48@@ -38,32 +41,40 @@
49
50         # check for Dim.DYNAMIC specializations; special case error message on 0/1
51         if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
52-            path = f"inputs{pytree.keystr(keypath)}"
53+            path = f"inputs{torch.utils._pytree.keystr(keypath)}"
54             if i is not None:
55                 path += f".shape[{i}]"
56             if (
57                 trace_vr.is_singleton()
58                 and hint in (0, 1)
59-                and not torch.fx.experimental._config.backed_size_oblivious
60+                # PATCHED: line removed
61+                # and not torch.fx.experimental._config.backed_size_oblivious
62             ):
63-                msg = (
64-                    f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
65-                    f"but export 0/1 specialized due to hint of {hint} for dimension {path}."
66-                )
67+                return None
68+                # PATCHED: line removed
69+                # msg = (
70+                #     f"- Received user-specified dim hint "
71+                #     f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
72+                #     f"but export 0/1 specialized due to hint of "
73+                #     f"{hint} for dimension {path}."
74+                # )
75             else:
76                 msg = (
77-                    f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
78-                    f"but tracing inferred a static shape of {out_vr.lower} for dimension {path}."
79+                    f"- Received user-specified dim hint "
80+                    f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
81+                    f"but tracing inferred a static shape of "
82+                    f"{out_vr.lower} for dimension {path}."
83                 )
84             return msg
85
86     except torch.utils._sympy.value_ranges.ValueRangeError:
87-        path = f"inputs{pytree.keystr(keypath)}"
88+        path = f"inputs{torch.utils._pytree.keystr(keypath)}"
89         if i is not None:
90             path += f".shape[{i}]"
91         msg = (
92             f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
93-            f"conflicting with the inferred min/max range of [{trace_vr.lower}, {trace_vr.upper}], "
94+            f"conflicting with the inferred min/max range of "
95+            f"[{trace_vr.lower}, {trace_vr.upper}], "
96             f"for {path}."
97         )
98         return msg

torch: _broadcast_in_dim_meta -> patched__broadcast_in_dim_meta

 1--- original
 2+++ rewritten
 3@@ -1,6 +1,9 @@
 4-def _broadcast_in_dim_meta(
 5-    a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
 6+def patched__broadcast_in_dim_meta(
 7+    a: torch._prims_common.TensorLikeType,
 8+    shape: torch._prims_common.ShapeType,
 9+    broadcast_dimensions: Sequence[int],
10 ):
11+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
12     from torch.fx.experimental.symbolic_shapes import (
13         guard_or_false,
14         guard_or_true,
15@@ -8,7 +11,7 @@
16     )
17
18     # Type checks
19-    assert isinstance(a, TensorLike)
20+    assert isinstance(a, torch._prims_common.TensorLike)
21     assert isinstance(shape, Sequence)
22     assert isinstance(broadcast_dimensions, Sequence)
23
24@@ -22,7 +25,7 @@
25     # (no relative reordering of dims) of integers and
26     # each dimension must be within the new shape
27     def _greater_than_reduce(acc, x):
28-        assert isinstance(x, Dim)
29+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
30         assert x > acc
31         assert x < len(shape)
32
33@@ -34,7 +37,9 @@
34     for idx, new_idx in enumerate(broadcast_dimensions):
35         torch._check(
36             sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
37-            lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
38+            lambda idx=idx, new_idx=new_idx: (
39+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
40+            ),
41         )
42
43     new_strides = []
44@@ -48,10 +53,26 @@
45                     new_strides.append(a.stride()[original_idx])
46                 else:
47                     new_strides.append(0)
48+            # PATCHED: disabled this check
49+            elif guard_or_false(a.shape[original_idx] != 1):
50+                new_strides.append(a.stride()[original_idx])
51             else:
52+                # This checks generates the following issue:
53+                # non-broadcasting semantics require s3 == Max(s10, s3), False,
54+                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
55+                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
56+                # original_idx=1
57                 torch._check(
58                     a.shape[original_idx] == shape[idx],
59-                    lambda: f"non-broadcasting semantics require {a.shape[original_idx]} == {shape[idx]}",
60+                    lambda idx=idx, original_idx=original_idx: (
61+                        f"non-broadcasting semantics require "
62+                        f"{a.shape[original_idx]} == {shape[idx]}, "
63+                        f"{guard_or_false(a.shape[idx] != 1)}, "
64+                        f"guard_or_false(a.shape[idx]==1)="
65+                        f"{guard_or_false(a.shape[idx] == 1)}, "
66+                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
67+                        f"shape={shape}, original_idx={original_idx}"
68+                    ),
69                 )
70                 new_strides.append(a.stride()[original_idx])
71             original_idx = original_idx + 1

torch: _maybe_broadcast -> patched__maybe_broadcast

 1--- original
 2+++ rewritten
 3@@ -1,6 +1,9 @@
 4-def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
 5+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
 6+    """Patches ``torch._refs._maybe_broadcast``."""
 7+    from torch._prims_common import ShapeType, TensorLike, Number
 8+
 9     # Computes common shape
10-    common_shape = _broadcast_shapes(
11+    common_shape = patched__broadcast_shapes(
12         *(t.shape if isinstance(t, TensorLike) else None for t in args)
13     )
14
15@@ -29,10 +32,15 @@
16                 return True
17
18             # u0==u1 assume the same, no broadcasting!
19-            torch._check(
20-                x == y,
21-                lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
22-            )
23+            # PATCHED: avoid errors
24+            return True  # guard_or_true(x != y)
25+            # torch._check(
26+            #    x == y,
27+            #    lambda x=x, y=y: (
28+            #        f"sizes assumed to be the same due to unbacked "
29+            #        f"broadcasting semantics x={x!r}, y={y!r}"
30+            #    ),
31+            # )
32
33         return False
34
35@@ -42,7 +50,7 @@
36         elif isinstance(x, Number):
37             return x
38         elif isinstance(x, TensorLike):
39-            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
40+            if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
41                 return x
42
43             if should_expand(x.shape, common_shape):
44@@ -50,6 +58,6 @@
45
46             return x
47         else:
48-            raise RuntimeError("Unexpected type when broadcasting: " + str(type(x)) + "!")
49+            raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
50
51     return tuple(__maybe_broadcast(x, common_shape) for x in args)

torch: ShapeEnv._evaluate_expr -> patched_ShapeEnv._evaluate_expr

 1--- original
 2+++ rewritten
 3@@ -1,14 +1,24 @@
 4 def _evaluate_expr(
 5     self,
 6-    orig_expr: sympy.Basic,
 7+    orig_expr: "sympy.Basic",  # noqa: F821
 8     hint: Optional[Union[bool, int, float]] = None,
 9     fx_node: Optional[torch.fx.Node] = None,
10     size_oblivious: bool = False,
11     fallback_value: Optional[bool] = None,
12     *,
13     forcing_spec: bool = False,
14-) -> sympy.Basic:
15+) -> "sympy.Basic":  # noqa: F821
16     # TODO: split conjunctions and evaluate them separately
17+    import sympy
18+    from torch.fx.experimental import _config as config
19+    from torch.fx.experimental.symbolic_shapes import (
20+        SympyBoolean,
21+        log,
22+        SymT,
23+        symbol_is_type,
24+    )
25+    from torch._guards import ShapeGuard
26+
27     if isinstance(
28         orig_expr,
29         (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
30@@ -118,7 +128,8 @@
31                     self._log_suppressed_dde(orig_expr, fallback_value)
32                     return fallback_value
33
34-                # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type.
35+                # oblivious_var_to_val will be defined iff we have sizes
36+                # with DimDynamic.OBLIVIOUS_SIZE type.
37                 # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
38                 if (
39                     self.oblivious_var_to_val
40@@ -145,7 +156,8 @@
41                     ok = True
42
43                 # unbacked_var_to_val is not None iff propagate_real_tensors is on.
44-                # if propagate_real_tensors is on, we check the example values to generate (unsound_result)
45+                # if propagate_real_tensors is on, we check the example values
46+                # to generate (unsound_result)
47                 # and if they pass we add a runtime assertions and continue.
48                 if (
49                     not ok
50@@ -163,19 +175,22 @@
51                     concrete_val = unsound_result
52                     ok = True
53
54-                # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion
55+                # Check if this is coming from a python assert statement,
56+                # if so, convert it to a runtime assertion
57                 # instead of failing.
58                 if not ok and self.trace_asserts and self._is_python_assert():
59                     concrete_val = sympy.true
60                     transmute_into_runtime_assert = True
61                     ok = True
62
63-                if not ok:
64-                    raise self._make_data_dependent_error(
65-                        expr.xreplace(self.var_to_val),
66-                        expr,
67-                        expr_sym_node_id=self._expr_sym_node_id,
68-                    )
69+                # PATCHED: ok -> True
70+                ok = True
71+                # if not ok:
72+                #    raise self._make_data_dependent_error(
73+                #        expr.xreplace(self.var_to_val),
74+                #        expr,
75+                #        expr_sym_node_id=self._expr_sym_node_id,
76+                #    )
77             else:
78                 expr = new_expr

patch_transformers: dynamic_rope_update -> patched_dynamic_rope_update

  1--- original
  2+++ rewritten
  3@@ -1,99 +1,193 @@
  4-def dynamic_rope_update(rope_forward):
  5-    """
  6-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
  7-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
  8+def patched_dynamic_rope_update(rope_forward):
  9+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 10
 11-    Args:
 12-        rope_forward (Callable):
 13-            The forward pass of the RoPE implementation.
 14+    ``rope_type`` is determined in the constructor of class
 15+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 16
 17-    Returns:
 18-        The decorated forward pass.
 19+    .. code-block:: python
 20+
 21+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 22+            self.rope_type = config.rope_scaling.get(
 23+                "rope_type", config.rope_scaling.get("type"))
 24+        else:
 25+            self.rope_type = "default"
 26+
 27+    The original code of the patched function:
 28+
 29+    .. code-block:: python
 30+
 31+        def dynamic_rope_update(rope_forward):
 32+            def longrope_frequency_update(self, position_ids, device):
 33+                seq_len = torch.max(position_ids) + 1
 34+                if hasattr(self.config, "original_max_position_embeddings"):
 35+                    original_max_position_embeddings =
 36+                        self.config.original_max_position_embeddings
 37+                else:
 38+                    original_max_position_embeddings =
 39+                        self.config.max_position_embeddings
 40+                if seq_len > original_max_position_embeddings:
 41+                    if not hasattr(self, "long_inv_freq"):
 42+                        self.long_inv_freq, _ = self.rope_init_fn(
 43+                            self.config, device, seq_len=original_max_position_embeddings + 1
 44+                        )
 45+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 46+                else:
 47+                    self.original_inv_freq = self.original_inv_freq.to(device)
 48+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 49+
 50+            def dynamic_frequency_update(self, position_ids, device):
 51+                seq_len = torch.max(position_ids) + 1
 52+                if seq_len > self.max_seq_len_cached:  # growth
 53+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 54+                        self.config, device, seq_len=seq_len)
 55+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 56+                    self.max_seq_len_cached = seq_len
 57+
 58+                if seq_len < self.original_max_seq_len and
 59+                        self.max_seq_len_cached > self.original_max_seq_len:
 60+                    self.original_inv_freq = self.original_inv_freq.to(device)
 61+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 62+                    self.max_seq_len_cached = self.original_max_seq_len
 63+
 64+            @wraps(rope_forward)
 65+            def wrapper(self, x, position_ids):
 66+                if "dynamic" in self.rope_type:
 67+                    dynamic_frequency_update(self, position_ids, device=x.device)
 68+                elif self.rope_type == "longrope":
 69+                    longrope_frequency_update(self, position_ids, device=x.device)
 70+                return rope_forward(self, x, position_ids)
 71+
 72+            return wrapper
 73+
 74     """
 75
 76     def longrope_frequency_update(self, position_ids, device, layer_type=None):
 77-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
 78+        # It is no use to patch the function after the model is created
 79+        # as rope_init_fn is an attribute set to one function when the model
 80+        # is created and when no patch is applied yet.
 81+        # So we select the patched version here.
 82+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
 83         seq_len = torch.max(position_ids) + 1
 84-        original_max_position_embeddings = getattr(
 85-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
 86-        )
 87+        if hasattr(self.config, "original_max_position_embeddings"):
 88+            original_max_position_embeddings = self.config.original_max_position_embeddings
 89+        else:
 90+            original_max_position_embeddings = self.config.max_position_embeddings
 91+
 92         if layer_type is None:
 93-            rope_type = self.rope_type
 94+            # rope_type = self.rope_type
 95             original_inv_freq = self.original_inv_freq
 96             prefix = ""
 97         else:
 98-            rope_type = self.rope_type[layer_type]
 99+            # rope_type = self.rope_type[layer_type]
100             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
101             prefix = f"{layer_type}_"
102
103-        if seq_len > original_max_position_embeddings:
104-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
105-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
106-                long_inv_freq, _ = rope_init_fn(
107-                    self.config,
108-                    device,
109-                    seq_len=original_max_position_embeddings + 1,
110-                    layer_type=layer_type,
111-                )
112-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
113-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
114-        else:
115-            # This .to() is needed if the model has been moved to a device after being initialized (because
116-            # the buffer is automatically moved, but not the original copy)
117-            original_inv_freq = original_inv_freq.to(device)
118-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
119-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
120+        # At export time, seq_len is unknown.
121+        long_inv_freq, _ = rope_init_fn(
122+            self.config, device, seq_len=original_max_position_embeddings + 1
123+        )
124+        original_inv_freq = self.original_inv_freq.to(device)
125+
126+        # PATCHED: uses torch.cond instead of a test
127+        cond = (seq_len > original_max_position_embeddings).item()
128+        inv_freq = torch.cond(
129+            cond,
130+            (lambda x, y: x.clone()),
131+            (lambda x, y: y.clone()),
132+            [long_inv_freq, original_inv_freq],
133+        )
134+        setattr(self, f"{prefix}inv_freq", inv_freq)
135+        # if seq_len > original_max_position_embeddings:
136+        #    self.inv_freq = self.long_inv_freq
137+        # else:
138+        #    self.inv_freq = self.original_inv_freq
139
140     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
141-        """
142-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
143-        1 - growing beyond the cached sequence length (allow scaling)
144-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
145-        """
146+        # constructor:
147+        # - self.max_seq_len_cached = config.max_position_embeddings
148+        # - self.original_max_seq_len = config.max_position_embeddings
149+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
150+
151+        # It is no use to patch the function after the model is created
152+        # as rope_init_fn is an attribute set to one function when the model
153+        # is created and when no patch is applied yet.
154+        # So we select the patched version here.
155+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
156+
157+        # This behaviour is difficult to translate.
158+        # The sequence always grows.
159+        # The test should always True.
160+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
161+        #
162+        # if seq_len > self.max_seq_len_cached:  # growth
163+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
164+        #        self.config, device, seq_len=seq_len
165+        #    )
166+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
167+        #    self.max_seq_len_cached = seq_len
168+        #
169+        # So we should not need what follows.
170+        #
171+        # cond = (seq_len > self.max_seq_len_cached).item()
172+        # self.attention_scaling = torch.cond(
173+        #    cond,
174+        #    (lambda x, y: x.clone()),
175+        #    (lambda x, y: y.clone()),
176+        #    [attention_scaling, self.attention_scaling],
177+        # )
178+
179         seq_len = torch.max(position_ids) + 1
180+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
181+
182         if layer_type is None:
183-            rope_type = self.rope_type
184-            max_seq_len_cached = self.max_seq_len_cached
185+            # rope_type = self.rope_type
186+            # max_seq_len_cached = self.max_seq_len_cached
187             original_inv_freq = self.original_inv_freq
188             prefix = ""
189         else:
190-            rope_type = self.rope_type[layer_type]
191-            max_seq_len_cached = getattr(
192-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
193-            )
194+            # rope_type = self.rope_type[layer_type]
195+            # max_seq_len_cached = getattr(
196+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
197+            # )
198             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
199             prefix = f"{layer_type}_"
200
201-        if seq_len > max_seq_len_cached:  # growth
202-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
203-            inv_freq, self.attention_scaling = rope_init_fn(
204-                self.config,
205-                device,
206-                seq_len=seq_len,
207-                layer_type=layer_type,
208-            )
209-            # TODO joao: may break with compilation
210-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
211-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
212+        # Second test to translate.
213+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
214+        # But in that case the following condition is a way to restore the original cache.
215
216-        if (
217-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
218-        ):  # reset
219-            # This .to() is needed if the model has been moved to a device after being initialized (because
220-            # the buffer is automatically moved, but not the original copy)
221-            original_inv_freq = original_inv_freq.to(device)
222-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
223-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
224-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
225+        # if (
226+        #    seq_len < self.original_max_seq_len
227+        #    and self.max_seq_len_cached > self.original_max_seq_len
228+        # ):
229+        #    self.original_inv_freq = self.original_inv_freq.to(device)
230+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
231+        #    self.max_seq_len_cached = self.original_max_seq_len
232+
233+        original_inv_freq = self.original_inv_freq.to(device)
234+        cond = (seq_len >= self.original_max_seq_len).item()
235+        # PATCHED: uses torch.cond instead of a test
236+        inv_freq = torch.cond(
237+            cond,
238+            (lambda x, y: x.clone()),
239+            (lambda x, y: y.clone()),
240+            [long_inv_freq, original_inv_freq],
241+        )
242+        setattr(self, f"{prefix}inv_freq", inv_freq)
243
244     @wraps(rope_forward)
245     def wrapper(self, x, position_ids, layer_type=None):
246-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
247-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
248-        if "dynamic" in rope_type:
249-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
250-        elif rope_type == "longrope":
251-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
252-        return rope_forward(self, x, position_ids, **kwargs)
253+        if layer_type is None:
254+            if "dynamic" in self.rope_type:
255+                dynamic_frequency_update(self, position_ids, device=x.device)
256+            elif self.rope_type == "longrope":
257+                longrope_frequency_update(self, position_ids, device=x.device)
258+            return rope_forward(self, x, position_ids)
259+
260+        if "dynamic" in self.rope_type:
261+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
262+        elif self.rope_type == "longrope":
263+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
264+        return rope_forward(self, x, position_ids, layer_type=layer_type)
265
266     return wrapper

transformers: _vmap_for_bhqkv -> patched__vmap_for_bhqkv

 1--- original
 2+++ rewritten
 3@@ -1,25 +1,50 @@
 4-def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
 5-    """
 6-    Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
 7-    the batch and head indices as well if `bh_indices=True`.
 8-    Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
 9-    functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
10+def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
11+    """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
12+    from ...helpers import string_type
13
14-    Args:
15-        mask_function (`Callable`):
16-            The mask_function to vmap.
17-        bh_indices (`bool`, optional):
18-            Whether to vmap over the batch and head indices as well, or only q and kv indices.
19+    dimensions: List[Tuple[Optional[int], ...]] = [
20+        (None, None, None, 0),
21+        (None, None, 0, None),
22+    ]
23+    if bh_indices:
24+        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
25+    # reshape
26+    dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
27+    dimensions = tuple(reversed(dimensions))
28+    indices = tuple(shape.index(-1) for shape in dimensions)
29
30-    Returns:
31-        Callable: The vmapped function.
32-    """
33-    # We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
34-    dimensions = [(None, None, None, 0), (None, None, 0, None)]
35-    if bh_indices:
36-        # We extend broadcasting over the [batch_idx, head_idx] dimensions
37-        dimensions.extend([(None, 0, None, None), (0, None, None, None)])
38+    # unsqueeze
39+    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
40
41-    for dims in dimensions:
42-        mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
43-    return mask_function
44+    def vector_mask_function(
45+        *args, mask_function=mask_function, dimensions=dimensions, indices=indices
46+    ):
47+        assert len(args) == len(dimensions) == len(udimensions), (
48+            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
49+            f"and udimensions={udimensions}."
50+        )
51+        assert len(indices) == len(args), (
52+            f"Mismatch between args={string_type(args)} and indices={indices}, "
53+            f"they should have the same length."
54+        )
55+        for a in args:
56+            assert (
57+                a.ndim == 1
58+            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
59+            torch._check(a.shape[0] > 0)
60+
61+        new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
62+        # new_args = [
63+        #    a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
64+        #    for a, dims in zip(args, udimensions)
65+        # ]
66+        max_shape = tuple(args[i].shape[0] for i in indices)
67+        # if _is_torchdynamo_exporting():
68+        #     for a in args:
69+        #         # The exporter should export with a dimension > 1
70+        #         # to make sure it is dynamic.
71+        #         torch._check(a.shape[0] > 1)
72+        expanded_args = [a.expand(max_shape) for a in new_args]
73+        return mask_function(*expanded_args)
74+
75+    return vector_mask_function

transformers: sdpa_mask_recent_torch -> patched_sdpa_mask_recent_torch

  1--- original
  2+++ rewritten
  3@@ -1,4 +1,4 @@
  4-def sdpa_mask_recent_torch(
  5+def patched_sdpa_mask_recent_torch(
  6     batch_size: int,
  7     cache_position: torch.Tensor,
  8     kv_length: int,
  9@@ -10,145 +10,42 @@
 10     allow_is_bidirectional_skip: bool = False,
 11     **kwargs,
 12 ) -> Optional[torch.Tensor]:
 13-    """
 14-    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
 15-    the element should take part in the attention computation, and False that it should not.
 16-    This function can only be used with torch>=2.5, as the context manager is otherwise not available.
 17-
 18-    Args:
 19-        batch_size (`int`):
 20-            The batch size of the input sequence.
 21-        cache_position (`torch.Tensor`):
 22-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
 23-        kv_length (`int`):
 24-            The size that the key and value states will have during the attention computation.
 25-        kv_offset (`int`, optional):
 26-            An optional offset to indicate at which first position the key and values states will refer to.
 27-        mask_function (`Callable`):
 28-            The mask factory function describing the mask pattern.
 29-        attention_mask (`torch.Tensor`, optional):
 30-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
 31-        local_size (`int`, optional):
 32-            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
 33-            to try to skip mask creation if possible.
 34-        allow_is_causal_skip (`bool`, optional):
 35-            Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
 36-            `torch.sdpa` instead. Default to `True`.
 37-        allow_torch_fix (`bool`, optional):
 38-            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
 39-            versions. We need an arg to skip it when using eager. By default `True`.
 40-        allow_is_bidirectional_skip (`bool`, optional):
 41-            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
 42-            i.e. full attention without any padding. Default to `False`.
 43-
 44-
 45-    ## Creating a simple causal mask:
 46-
 47-    To create the following causal mask:
 48-
 49-        0 ■ ⬚ ⬚ ⬚ ⬚
 50-        1 ■ ■ ⬚ ⬚ ⬚
 51-        2 ■ ■ ■ ⬚ ⬚
 52-        3 ■ ■ ■ ■ ⬚
 53-        4 ■ ■ ■ ■ ■
 54-
 55-    You can do
 56-
 57-    ```python
 58-    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
 59-    >>> tensor([[[[ True, False, False, False, False],
 60-                  [ True,  True, False, False, False],
 61-                  [ True,  True,  True, False, False],
 62-                  [ True,  True,  True,  True, False],
 63-                  [ True,  True,  True,  True,  True]]]])
 64-    ```
 65-
 66-    ## Creating a sliding window mask:
 67-
 68-    To create the following sliding window mask (`sliding_window=3`):
 69-
 70-        0 ■ ⬚ ⬚ ⬚ ⬚
 71-        1 ■ ■ ⬚ ⬚ ⬚
 72-        2 ■ ■ ■ ⬚ ⬚
 73-        3 ⬚ ■ ■ ■ ⬚
 74-        4 ⬚ ⬚ ■ ■ ■
 75-
 76-    You can do
 77-
 78-    ```python
 79-    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
 80-    >>> tensor([[[[ True, False, False, False, False],
 81-                  [ True,  True, False, False, False],
 82-                  [ True,  True,  True, False, False],
 83-                  [False,  True,  True,  True, False],
 84-                  [False, False,  True,  True,  True]]]])
 85-    ```
 86-
 87-    ## Creating a chunked attention mask
 88-
 89-    To create the following chunked attention mask (`chunk_size=3`):
 90-
 91-        0 ■ ⬚ ⬚ ⬚ ⬚
 92-        1 ■ ■ ⬚ ⬚ ⬚
 93-        2 ■ ■ ■ ⬚ ⬚
 94-        3 ⬚ ⬚ ⬚ ■ ⬚
 95-        4 ⬚ ⬚ ⬚ ■ ■
 96-
 97-    You can do
 98-
 99-    ```python
100-    >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
101-    >>> tensor([[[[ True, False, False, False, False],
102-                [ True,  True, False, False, False],
103-                [ True,  True,  True, False, False],
104-                [False, False, False,  True, False],
105-                [False, False, False,  True,  True]]]])
106-    ```
107-
108-    """
109+    """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
110     q_length = cache_position.shape[0]
111-    # Potentially pad the 2D mask, and slice it correctly
112     padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
113-
114-    # Under specific conditions, we can avoid materializing the mask
115-    #   1. Causal masks can rely on the `is_causal` argument
116-    #   2. Bidirectional do not need any further processing (no bias)
117     if allow_is_causal_skip and _ignore_causal_mask_sdpa(
118         padding_mask, q_length, kv_length, kv_offset, local_size
119     ):
120         return None
121-    if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
122+    if (
123+        allow_is_bidirectional_skip
124+        and _ignore_bidirectional_mask_sdpa
125+        and _ignore_bidirectional_mask_sdpa(padding_mask)
126+    ):
127         return None
128
129-    # vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
130-    # padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
131     if mask_function is bidirectional_mask_function:
132         if padding_mask is not None:
133             # used for slicing without data-dependent slicing
134             mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
135             return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
136-        else:
137-            return torch.ones(
138-                batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device
139-            )
140+        return torch.ones(
141+            batch_size,
142+            1,
143+            q_length,
144+            kv_length,
145+            dtype=torch.bool,
146+            device=cache_position.device,
147+        )
148
149-    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
150-    # but without data-dependent slicing (i.e. torch.compile friendly)
151     kv_arange = torch.arange(kv_length, device=cache_position.device)
152     kv_arange += kv_offset
153-
154-    # Potentially add the padding 2D mask
155     if padding_mask is not None:
156         mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
157-
158     batch_arange = torch.arange(batch_size, device=cache_position.device)
159     head_arange = torch.arange(1, device=cache_position.device)
160-    # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
161-    # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
162-    # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
163-    with TransformGetItemToIndex():
164-        causal_mask = _vmap_for_bhqkv(mask_function)(
165-            batch_arange, head_arange, cache_position, kv_arange
166-        )
167-
168+    # PATCHED: this line calls the patched version of vmap_for_bhqkv
169+    causal_mask = patched__vmap_for_bhqkv(mask_function)(
170+        batch_arange, head_arange, cache_position, kv_arange
171+    )
172     return causal_mask

transformers: eager_mask -> patched_eager_mask

 1--- original
 2+++ rewritten
 3@@ -1,4 +1,4 @@
 4-def eager_mask(
 5+def patched_eager_mask(
 6     batch_size: int,
 7     cache_position: torch.Tensor,
 8     kv_length: int,
 9@@ -8,31 +8,12 @@
10     dtype: torch.dtype = torch.float32,
11     **kwargs,
12 ) -> torch.Tensor:
13-    """
14-    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
15-    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
16-    it should not.
17-
18-    Args:
19-        batch_size (`int`):
20-            The batch size of the input sequence.
21-        cache_position (`torch.Tensor`):
22-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
23-        kv_length (`int`):
24-            The size that the key and value states will have during the attention computation.
25-        kv_offset (`int`, optional):
26-            An optional offset to indicate at which first position the key and values states will refer to.
27-        mask_function (`Callable`):
28-            The mask factory function describing the mask pattern.
29-        attention_mask (`torch.Tensor`, optional):
30-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
31-        dtype (`torch.dtype`, optional):
32-            The dtype to use for the mask. By default, `torch.float32`.
33-    """
34+    """manual patch for function ``transformers.masking_utils.eager_mask``."""
35     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
36     _ = kwargs.pop("allow_is_causal_skip", None)
37     _ = kwargs.pop("allow_is_bidirectional_skip", None)
38-    mask = sdpa_mask(
39+    # PATCHED: this line called the patched version of sdpa_mask
40+    mask = patched_sdpa_mask_recent_torch(
41         batch_size=batch_size,
42         cache_position=cache_position,
43         kv_length=kv_length,
44@@ -45,6 +26,10 @@
45         **kwargs,
46     )
47     min_dtype = torch.finfo(dtype).min
48-    # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
49-    mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
50+    # PATCHED: the following line
51+    # we need 0s where the tokens should be taken into account,
52+    # and -inf otherwise (mask is already of boolean type)
53+    # mask =
54+    #   torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
55+    mask = (~mask).to(dtype) * min_dtype
56     return mask

transformers: sdpa_attention_forward -> patched_sdpa_attention_forward

  1--- original
  2+++ rewritten
  3@@ -1,4 +1,4 @@
  4-def sdpa_attention_forward(
  5+def patched_sdpa_attention_forward(
  6     module: torch.nn.Module,
  7     query: torch.Tensor,
  8     key: torch.Tensor,
  9@@ -9,60 +9,121 @@
 10     is_causal: Optional[bool] = None,
 11     **kwargs,
 12 ) -> tuple[torch.Tensor, None]:
 13-    if kwargs.get("output_attentions", False):
 14-        logger.warning_once(
 15-            "`sdpa` attention does not support `output_attentions=True`."
 16-            " Please set your attention to `eager` if you want any of these features."
 17-        )
 18+    """
 19+    manual patch for function
 20+    ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
 21+    """
 22+    assert not kwargs.get("output_attentions", False), (
 23+        "`sdpa` attention does not support `output_attentions=True`."
 24+        " Please set your attention to `eager` if you want any of these features."
 25+    )
 26+    torch._check(
 27+        query.shape[0] == key.shape[0] or query.shape[0] == 1,
 28+        lambda: (
 29+            f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
 30+            f"value: {value.shape}"
 31+        ),
 32+    )
 33+    torch._check(
 34+        key.shape[0] == value.shape[0] or key.shape[0] == 1,
 35+        lambda: (
 36+            f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
 37+            f"value: {value.shape}"
 38+        ),
 39+    )
 40+
 41     sdpa_kwargs = {}
 42     if hasattr(module, "num_key_value_groups"):
 43-        if not use_gqa_in_sdpa(attention_mask, key):
 44-            key = repeat_kv(key, module.num_key_value_groups)
 45-            value = repeat_kv(value, module.num_key_value_groups)
 46+        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
 47+            key = transformers.integrations.sdpa_attention.repeat_kv(
 48+                key, module.num_key_value_groups
 49+            )
 50+            value = transformers.integrations.sdpa_attention.repeat_kv(
 51+                value, module.num_key_value_groups
 52+            )
 53         else:
 54             sdpa_kwargs = {"enable_gqa": True}
 55
 56     if attention_mask is not None and attention_mask.ndim == 4:
 57         attention_mask = attention_mask[:, :, :, : key.shape[-2]]
 58
 59-    # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
 60-    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
 61+    torch._check(
 62+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
 63+        lambda: "Attention mask shape incompatible with key shape.",
 64+    )
 65
 66-    # SDPA's Flash Attention (and cuDNN) kernels rely on the `is_causal` flag. However, there are certain conditions:
 67-    # - Not in decoding phase (otherwise we want full attention on the single query token)
 68-    # - Attention mask is not to be provided (even if it is a causal pattern)
 69-    # - Internally, we marked this as compatible with causal, i.e. it is a decoder attention type
 70-    #
 71-    # Quirks on the conditionals:
 72-    # - We avoid inline passing this to the SDPA function directly to support both torch.compile's dynamic shapes and
 73-    #   full graph options. Otherwise, dynamic shapes are prevented from compiling.
 74-    # - It is important to check first for the shape, otherwise compile will fail with
 75-    #   `argument 'is_causal' must be bool, not SymBool`.
 76-    is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
 77+    if patch_sdpa_is_causal:
 78+        # transformers>=4.55
 79+        is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
 80
 81-    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
 82-    # We convert it to a bool for the SDPA kernel that only accepts bools.
 83-    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
 84-        is_causal = is_causal.item()
 85+        # PATCHED: remove the test query.shape[2] > 1
 86+        # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
 87+    # and we split the test to keep as little as possible inside torch.cond
 88+        is_causal = attention_mask is None and is_causal
 89
 90-    # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator,
 91-    # and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
 92-    # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
 93-    if _is_torch_npu_available:
 94-        if attention_mask is not None and attention_mask.dtype != torch.bool:
 95-            # Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
 96-            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
 97+        if not is_causal:
 98+            return (
 99+                torch.nn.functional.scaled_dot_product_attention(
100+                    query,
101+                    key,
102+                    value,
103+                    attn_mask=attention_mask,
104+                    dropout_p=dropout,
105+                    scale=scaling,
106+                    is_causal=is_causal,
107+                    **sdpa_kwargs,
108+                )
109+                .transpose(1, 2)
110+                .contiguous(),
111+                None,
112+            )
113+    else:
114+        # transformers<4.55
115+        if is_causal is None and attention_mask is not None:
116+            is_causal = False
117+        if is_causal is not None:
118+            return (
119+                torch.nn.functional.scaled_dot_product_attention(
120+                    query,
121+                    key,
122+                    value,
123+                    attn_mask=attention_mask,
124+                    dropout_p=dropout,
125+                    scale=scaling,
126+                    is_causal=is_causal,
127+                    **sdpa_kwargs,
128+                )
129+                .transpose(1, 2)
130+                .contiguous(),
131+                None,
132+            )
133
134-    attn_output = torch.nn.functional.scaled_dot_product_attention(
135-        query,
136-        key,
137-        value,
138-        attn_mask=attention_mask,
139-        dropout_p=dropout,
140-        scale=scaling,
141-        is_causal=is_causal,
142-        **sdpa_kwargs,
143+    # To avoid the following errors:
144+    # is_causal=query.shape[2] > 1
145+    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
146+    # is_causal=torch.tensor(query.shape[2] > 1)
147+    # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
148+    attn_output = torch.cond(
149+        query.shape[2] > 1,  # distinction between prefill and decoding steps
150+        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
151+            query,
152+            key,
153+            value,
154+            dropout_p=dropout,
155+            scale=scaling,
156+            is_causal=True,
157+            **sdpa_kwargs,
158+        ),
159+        lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
160+            query,
161+            key,
162+            value,
163+            dropout_p=dropout,
164+            scale=scaling,
165+            is_causal=False,
166+            **sdpa_kwargs,
167+        ),
168+        [query, key, value],
169     )
170     attn_output = attn_output.transpose(1, 2).contiguous()
171-
172     return attn_output, None
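
The important part of this rewrite is the final torch.cond: at export time query.shape[2] > 1 is a symbolic boolean that scaled_dot_product_attention rejects as is_causal, so both branches are kept in the graph and the decision is made at runtime. The sketch below only illustrates that pattern; the module name and shapes are invented, and the real patch also forwards the dropout, scaling and GQA arguments.

import torch

class CondSDPA(torch.nn.Module):
    def forward(self, query, key, value):
        # query.shape[2] > 1 distinguishes prefill from decoding; torch.cond
        # keeps both branches instead of specializing on the symbolic value
        return torch.cond(
            query.shape[2] > 1,
            lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, is_causal=True
            ),
            lambda q, k, v: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, is_causal=False
            ),
            [query, key, value],
        )

q = torch.rand(1, 2, 4, 8)
seq = torch.export.Dim("seq")
ep = torch.export.export(
    CondSDPA(), (q, q, q), dynamic_shapes=({2: seq}, {2: seq}, {2: seq})
)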

auto/patch_transformers: DynamicLayer.lazy_initialization -> patched_DynamicLayer.lazy_initialization

 1--- original
 2+++ rewritten
 3@@ -1,5 +1,9 @@
 4 def lazy_initialization(self, key_states: torch.Tensor):
 5     self.dtype, self.device = key_states.dtype, key_states.device
 6-    self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
 7-    self.values = torch.tensor([], dtype=self.dtype, device=self.device)
 8-    self.is_initialized = True
 9+    new_shape = list(key_states.shape)
10+    new_shape[-2] = 0
11+    # PATCHED: uses a tensor with an empty shape and not an empty list to initialize
12+    self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
13+    self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
14+    if patch_is_initialized:
15+        self.is_initialized = True
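
The effect of the change is easier to see on its own: the patched initialization gives the empty cache the same rank as key_states with a zero-sized sequence axis, so the very first torch.cat along dim=-2 behaves like every later one. Shapes below are illustrative:

import torch

key_states = torch.rand(2, 4, 3, 8)  # (batch, heads, seq, head_dim)

# original: a 1-D empty tensor with no sequence axis
flat_keys = torch.tensor([], dtype=key_states.dtype)
# patched: same rank as key_states, 0 on the sequence axis
shaped_keys = torch.empty((2, 4, 0, 8), dtype=key_states.dtype)

print(torch.cat([shaped_keys, key_states], dim=-2).shape)  # torch.Size([2, 4, 3, 8])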

auto/patch_transformers: Gemma2RotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always be True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper
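
Stripped of the RoPE details, both frequency updates follow the same pattern: compute the two candidate inv_freq tensors, then let torch.cond select one from a data-dependent condition instead of branching in Python. A standalone sketch with made-up values:

import torch

position_ids = torch.arange(6)[None, :]
long_inv_freq = torch.rand(16)
original_inv_freq = torch.rand(16)
original_max_position_embeddings = 4

seq_len = torch.max(position_ids) + 1
inv_freq = torch.cond(
    (seq_len > original_max_position_embeddings).item(),
    lambda x, y: x.clone(),  # sequence longer than the pretraining length
    lambda x, y: y.clone(),  # keep the original frequencies
    [long_inv_freq, original_inv_freq],
)
print(inv_freq.shape)  # torch.Size([16])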

auto/patch_transformers: Gemma3Model.get_placeholder_mask -> patched_Gemma3Model.get_placeholder_mask

 1--- original
 2+++ rewritten
 3@@ -4,14 +4,12 @@
 4     inputs_embeds: torch.FloatTensor,
 5     image_features: torch.FloatTensor,
 6 ):
 7-    """
 8-    Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
 9-    equal to the length of multimodal features. If the lengths are different, an error is raised.
10-    """
11     if input_ids is None:
12         special_image_mask = inputs_embeds == self.get_input_embeddings()(
13             torch.tensor(
14-                self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
15+                self.config.image_token_id,
16+                dtype=torch.long,
17+                device=inputs_embeds.device,
18             )
19         )
20         special_image_mask = special_image_mask.all(-1)
21@@ -23,8 +21,14 @@
22         special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
23     )
24     n_image_features = image_features.shape[0] * image_features.shape[1]
25-    if inputs_embeds[special_image_mask].numel() != image_features.numel():
26-        raise ValueError(
27-            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
28-        )
29+    # PATCHED: torch._check
30+    # if inputs_embeds[special_image_mask].numel() != image_features.numel():
31+    #    raise ValueError( ... )
32+    torch._check(
33+        inputs_embeds[special_image_mask].numel() == image_features.numel(),
34+        lambda: (
35+            f"Image features and image tokens do not match: tokens: "
36+            f"{n_image_tokens}, features {n_image_features}"
37+        ),
38+    )
39     return special_image_mask
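
The same trick appears in several patches: a data-dependent if/raise cannot be traced, while torch._check records the invariant for the exporter and still fails when it is violated. A minimal illustration with made-up shapes:

import torch

inputs_embeds = torch.rand(1, 3, 4)
special_image_mask = torch.zeros(1, 3, 4, dtype=torch.bool)
special_image_mask[:, 1, :] = True    # one placeholder token
image_features = torch.rand(1, 1, 4)  # 1 * 1 tokens of size 4

torch._check(
    inputs_embeds[special_image_mask].numel() == image_features.numel(),
    lambda: "Image features and image tokens do not match",
)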

auto/patch_transformers: Gemma3RotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,13 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6+@patched_dynamic_rope_update
  7 def forward(self, x, position_ids, layer_type=None):
  8-    inv_freq = getattr(self, f"{layer_type}_inv_freq")
  9-    attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 10+    if layer_type is not None:
 11+        # transformers>=5.0
 12+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 13+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 14+    else:
 15+        # transformers<5.0
 16+        inv_freq = self.inv_freq
 17+        attention_scaling = self.attention_scaling
 18
 19     inv_freq_expanded = (
 20         inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21
 22--- original
 23+++ rewritten
 24@@ -1,99 +1,193 @@
 25-def dynamic_rope_update(rope_forward):
 26-    """
 27-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 28-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 29+def patched_dynamic_rope_update(rope_forward):
 30+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 31
 32-    Args:
 33-        rope_forward (Callable):
 34-            The forward pass of the RoPE implementation.
 35+    ``rope_type`` is determined in the constructor of class
 36+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 37
 38-    Returns:
 39-        The decorated forward pass.
 40+    .. code-block:: python
 41+
 42+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 43+            self.rope_type = config.rope_scaling.get(
 44+                "rope_type", config.rope_scaling.get("type"))
 45+        else:
 46+            self.rope_type = "default"
 47+
 48+    The original code of the patched function:
 49+
 50+    .. code-block:: python
 51+
 52+        def dynamic_rope_update(rope_forward):
 53+            def longrope_frequency_update(self, position_ids, device):
 54+                seq_len = torch.max(position_ids) + 1
 55+                if hasattr(self.config, "original_max_position_embeddings"):
 56+                    original_max_position_embeddings =
 57+                        self.config.original_max_position_embeddings
 58+                else:
 59+                    original_max_position_embeddings =
 60+                        self.config.max_position_embeddings
 61+                if seq_len > original_max_position_embeddings:
 62+                    if not hasattr(self, "long_inv_freq"):
 63+                        self.long_inv_freq, _ = self.rope_init_fn(
 64+                            self.config, device, seq_len=original_max_position_embeddings + 1
 65+                        )
 66+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 67+                else:
 68+                    self.original_inv_freq = self.original_inv_freq.to(device)
 69+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 70+
 71+            def dynamic_frequency_update(self, position_ids, device):
 72+                seq_len = torch.max(position_ids) + 1
 73+                if seq_len > self.max_seq_len_cached:  # growth
 74+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 75+                        self.config, device, seq_len=seq_len)
 76+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 77+                    self.max_seq_len_cached = seq_len
 78+
 79+                if seq_len < self.original_max_seq_len and
 80+                        self.max_seq_len_cached > self.original_max_seq_len:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+                    self.max_seq_len_cached = self.original_max_seq_len
 84+
 85+            @wraps(rope_forward)
 86+            def wrapper(self, x, position_ids):
 87+                if "dynamic" in self.rope_type:
 88+                    dynamic_frequency_update(self, position_ids, device=x.device)
 89+                elif self.rope_type == "longrope":
 90+                    longrope_frequency_update(self, position_ids, device=x.device)
 91+                return rope_forward(self, x, position_ids)
 92+
 93+            return wrapper
 94+
 95     """
 96
 97     def longrope_frequency_update(self, position_ids, device, layer_type=None):
 98-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
 99+        # It is no use to patch the function after the model is created
100+        # as rope_init_fn is an attribute set to one function when the model
101+        # is created and when no patch is applied yet.
102+        # So we select the patched version here.
103+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
104         seq_len = torch.max(position_ids) + 1
105-        original_max_position_embeddings = getattr(
106-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
107-        )
108+        if hasattr(self.config, "original_max_position_embeddings"):
109+            original_max_position_embeddings = self.config.original_max_position_embeddings
110+        else:
111+            original_max_position_embeddings = self.config.max_position_embeddings
112+
113         if layer_type is None:
114-            rope_type = self.rope_type
115+            # rope_type = self.rope_type
116             original_inv_freq = self.original_inv_freq
117             prefix = ""
118         else:
119-            rope_type = self.rope_type[layer_type]
120+            # rope_type = self.rope_type[layer_type]
121             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
122             prefix = f"{layer_type}_"
123
124-        if seq_len > original_max_position_embeddings:
125-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
126-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
127-                long_inv_freq, _ = rope_init_fn(
128-                    self.config,
129-                    device,
130-                    seq_len=original_max_position_embeddings + 1,
131-                    layer_type=layer_type,
132-                )
133-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
134-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
135-        else:
136-            # This .to() is needed if the model has been moved to a device after being initialized (because
137-            # the buffer is automatically moved, but not the original copy)
138-            original_inv_freq = original_inv_freq.to(device)
139-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
140-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
141+        # At export time, seq_len is unknown.
142+        long_inv_freq, _ = rope_init_fn(
143+            self.config, device, seq_len=original_max_position_embeddings + 1
144+        )
145+        original_inv_freq = self.original_inv_freq.to(device)
146+
147+        # PATCHED: uses torch.cond instead of a test
148+        cond = (seq_len > original_max_position_embeddings).item()
149+        inv_freq = torch.cond(
150+            cond,
151+            (lambda x, y: x.clone()),
152+            (lambda x, y: y.clone()),
153+            [long_inv_freq, original_inv_freq],
154+        )
155+        setattr(self, f"{prefix}inv_freq", inv_freq)
156+        # if seq_len > original_max_position_embeddings:
157+        #    self.inv_freq = self.long_inv_freq
158+        # else:
159+        #    self.inv_freq = self.original_inv_freq
160
161     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
162-        """
163-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
164-        1 - growing beyond the cached sequence length (allow scaling)
165-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
166-        """
167+        # constructor:
168+        # - self.max_seq_len_cached = config.max_position_embeddings
169+        # - self.original_max_seq_len = config.max_position_embeddings
170+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
171+
172+        # It is no use to patch the function after the model is created
173+        # as rope_init_fn is an attribute set to one function when the model
174+        # is created and when no patch is applied yet.
175+        # So we select the patched version here.
176+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
177+
178+        # This behaviour is difficult to translate.
179+        # The sequence always grows.
180+        # The test should always be True.
181+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
182+        #
183+        # if seq_len > self.max_seq_len_cached:  # growth
184+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
185+        #        self.config, device, seq_len=seq_len
186+        #    )
187+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
188+        #    self.max_seq_len_cached = seq_len
189+        #
190+        # So we should not need what follows.
191+        #
192+        # cond = (seq_len > self.max_seq_len_cached).item()
193+        # self.attention_scaling = torch.cond(
194+        #    cond,
195+        #    (lambda x, y: x.clone()),
196+        #    (lambda x, y: y.clone()),
197+        #    [attention_scaling, self.attention_scaling],
198+        # )
199+
200         seq_len = torch.max(position_ids) + 1
201+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
202+
203         if layer_type is None:
204-            rope_type = self.rope_type
205-            max_seq_len_cached = self.max_seq_len_cached
206+            # rope_type = self.rope_type
207+            # max_seq_len_cached = self.max_seq_len_cached
208             original_inv_freq = self.original_inv_freq
209             prefix = ""
210         else:
211-            rope_type = self.rope_type[layer_type]
212-            max_seq_len_cached = getattr(
213-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
214-            )
215+            # rope_type = self.rope_type[layer_type]
216+            # max_seq_len_cached = getattr(
217+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
218+            # )
219             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
220             prefix = f"{layer_type}_"
221
222-        if seq_len > max_seq_len_cached:  # growth
223-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
224-            inv_freq, self.attention_scaling = rope_init_fn(
225-                self.config,
226-                device,
227-                seq_len=seq_len,
228-                layer_type=layer_type,
229-            )
230-            # TODO joao: may break with compilation
231-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
232-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
233+        # Second test to translate.
234+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
235+        # But in that case the following condition is a way to restore the original cache.
236
237-        if (
238-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
239-        ):  # reset
240-            # This .to() is needed if the model has been moved to a device after being initialized (because
241-            # the buffer is automatically moved, but not the original copy)
242-            original_inv_freq = original_inv_freq.to(device)
243-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
244-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
246+        # if (
247+        #    seq_len < self.original_max_seq_len
248+        #    and self.max_seq_len_cached > self.original_max_seq_len
249+        # ):
250+        #    self.original_inv_freq = self.original_inv_freq.to(device)
251+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
252+        #    self.max_seq_len_cached = self.original_max_seq_len
253+
254+        original_inv_freq = self.original_inv_freq.to(device)
255+        cond = (seq_len >= self.original_max_seq_len).item()
256+        # PATCHED: uses torch.cond instead of a test
257+        inv_freq = torch.cond(
258+            cond,
259+            (lambda x, y: x.clone()),
260+            (lambda x, y: y.clone()),
261+            [long_inv_freq, original_inv_freq],
262+        )
263+        setattr(self, f"{prefix}inv_freq", inv_freq)
264
265     @wraps(rope_forward)
266     def wrapper(self, x, position_ids, layer_type=None):
267-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
268-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
269-        if "dynamic" in rope_type:
270-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
271-        elif rope_type == "longrope":
272-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
273-        return rope_forward(self, x, position_ids, **kwargs)
274+        if layer_type is None:
275+            if "dynamic" in self.rope_type:
276+                dynamic_frequency_update(self, position_ids, device=x.device)
277+            elif self.rope_type == "longrope":
278+                longrope_frequency_update(self, position_ids, device=x.device)
279+            return rope_forward(self, x, position_ids)
280+
281+        if "dynamic" in self.rope_type:
282+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
283+        elif self.rope_type == "longrope":
284+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
285+        return rope_forward(self, x, position_ids, layer_type=layer_type)
286
287     return wrapper

auto/patch_transformers: GemmaRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always be True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper

auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation -> patched_GenerationMixin._cache_dependant_input_preparation

 1--- original
 2+++ rewritten
 3@@ -3,23 +3,29 @@
 4     input_ids: torch.LongTensor,
 5     inputs_embeds: Optional[torch.FloatTensor],
 6     cache_position: Optional[torch.LongTensor],
 7-) -> tuple[torch.FloatTensor, torch.LongTensor]:
 8+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
 9     """
10     Generic cache-dependent input preparation
11     The code is put in a separate function to allow granular unit testing
12     as it needs a different implementation to be exportable.
13
14-    If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
15-    - Exception 1: when passing input_embeds, input_ids may be missing entries
16-    - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
17-    - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
18-    - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
19-      generate the first token for each sequence. Later use the generated Input ids for continuation.
20+    If we have cache: let's slice `input_ids` through `cache_position`,
21+    to keep only the unprocessed tokens
22+    - Exception 1: when passing input_embeds,
23+      input_ids may be missing entries
24+    - Exception 2: some generation methods do special slicing of input_ids,
25+      so we don't need to do it here
26+    - Exception 3: with synced GPUs cache_position may go out of bounds,
27+      but we only want dummy token in that case.
28+    - Exception 4: If input_embeds are passed then slice it through
29+      `cache_position`, to keep only the unprocessed tokens and
30+      generate the first token for each sequence.
31+      Later use the generated Input ids for continuation.
32
33     The current implementation does not rely on ``self`` and could be
34     a class method. It is left as a standard method to be easily rewritten.
35     """
36-    if is_torchdynamo_exporting():
37+    if _is_torchdynamo_exporting():
38         return self._cache_dependant_input_preparation_exporting(
39             input_ids, inputs_embeds, cache_position
40         )
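
The only behavioural change in this diff is the call to _is_torchdynamo_exporting, which routes the exporter to the torch.cond based implementation shown next while eager generation keeps the original Python control flow. A minimal sketch of that dispatch idea, using torch.compiler.is_compiling as a stand-in for the patch's private helper and an illustrative CacheSlicer module:

import torch

class CacheSlicer(torch.nn.Module):
    """Toy module: eager code keeps Python tests, export goes through torch.cond."""

    def forward(self, input_ids, cache_position):
        if torch.compiler.is_compiling():
            # Export/compile path: both branches stay in the graph.
            return torch.cond(
                input_ids.shape[1] != cache_position.shape[0],
                lambda ids, pos: ids[:, pos],
                lambda ids, pos: ids.clone(),
                [input_ids, cache_position],
            )
        # Eager path: plain data-dependent slicing.
        if input_ids.shape[1] != cache_position.shape[0]:
            return input_ids[:, cache_position]
        return input_ids

In eager mode only the plain Python path runs; the torch.cond branch is the one torch.export.export records.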

auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation_exporting -> patched_GenerationMixin._cache_dependant_input_preparation_exporting

 1--- original
 2+++ rewritten
 3@@ -3,7 +3,7 @@
 4     input_ids: torch.LongTensor,
 5     inputs_embeds: Optional[torch.FloatTensor],
 6     cache_position: Optional[torch.LongTensor],
 7-) -> tuple[torch.FloatTensor, torch.LongTensor]:
 8+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
 9     """
10     This method implements method ``_cache_dependant_input_preparation``
11     with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
12@@ -21,22 +21,21 @@
13         #     else:
14         #         if input_ids.shape[1] != cache_position.shape[0]:
15         #             input_ids = input_ids[:, cache_position]
16-        # We need to clone the outputs to avoid aliasing.
17         def branch_1(inputs_embeds, cache_position):
18-            return inputs_embeds[:, -cache_position.shape[0] :].clone()
19+            return inputs_embeds[:, -cache_position.shape[0] :]
20
21         def branch_2(input_ids, cache_position):
22-            return input_ids[:, -cache_position.shape[0] :].clone()
23+            return input_ids[:, -cache_position.shape[0] :]
24
25         def branch_3(input_ids, cache_position):
26-            return input_ids[:, cache_position].clone()
27+            return input_ids[:, cache_position]
28
29         inputs_embeds, input_ids = torch.cond(
30             input_ids.shape[1] == 0,
31             (
32                 lambda input_ids, inputs_embeds, cache_position: (
33                     branch_1(inputs_embeds, cache_position),
34-                    input_ids.clone(),
35+                    input_ids,
36                 )
37             ),
38             (
39@@ -49,7 +48,7 @@
40                             torch.cond(
41                                 input_ids.shape[1] != cache_position.shape[0],
42                                 branch_3,
43-                                (lambda input_ids, cache_position: input_ids.clone()),
44+                                (lambda input_ids, cache_position: input_ids),
45                                 [input_ids, cache_position],
46                             )
47                         ),

auto/patch_transformers: IdeficsAttention.forward -> patched_IdeficsAttention.forward

 1--- original
 2+++ rewritten
 3@@ -4,10 +4,12 @@
 4     key_value_states: Optional[torch.Tensor] = None,
 5     attention_mask: Optional[torch.Tensor] = None,
 6     position_ids: Optional[torch.LongTensor] = None,
 7-    past_key_values: Optional[Cache] = None,
 8+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 9+    output_attentions: bool = False,
10+    use_cache: bool = False,
11     cache_position: Optional[torch.LongTensor] = None,
12-    **kwargs: Unpack[TransformersKwargs],
13-) -> tuple[torch.Tensor, torch.Tensor]:
14+    **kwargs,
15+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
16     # if key_value_states are provided this layer is used as a cross-attention layer
17     is_cross_attention = self.is_cross_attention or key_value_states is not None
18
19@@ -43,20 +45,27 @@
20         )
21
22     kv_seq_len = key_states.shape[-2]
23-    if past_key_values is not None:
24+    if past_key_value is not None:
25         kv_seq_len += cache_position[0]
26
27     if not is_cross_attention:
28-        cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
29-        query_states, key_states = apply_rotary_pos_emb(
30-            query_states, key_states, cos, sin, position_ids
31+        rotary_length = torch.maximum(
32+            torch.tensor(kv_seq_len, dtype=torch.int64),
33+            torch.tensor(q_len, dtype=torch.int64),
34+        )
35+        cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
36+        query_states, key_states = (
37+            transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
38+                query_states, key_states, cos, sin, position_ids
39+            )
40         )
41     # [bsz, nh, t, hd]
42
43-    if past_key_values is not None:
44-        # sin and cos are specific to RoPE models; cache_position needed for the static cache
45+    if past_key_value is not None:
46+        # sin and cos are specific to RoPE models;
47+        # cache_position needed for the static cache
48         cache_kwargs = {"cache_position": cache_position}
49-        key_states, value_states = past_key_values.update(
50+        key_states, value_states = past_key_value.update(
51             key_states, value_states, self.layer_idx, cache_kwargs
52         )
53
54@@ -64,10 +73,22 @@
55         query_states = self.q_layer_norm(query_states)
56         key_states = self.k_layer_norm(key_states)
57
58-    attention_interface: Callable = eager_attention_forward
59+    attention_interface: Callable = (
60+        transformers.models.idefics.modeling_idefics.eager_attention_forward
61+    )
62
63     if self.config._attn_implementation != "eager":
64-        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
65+        if self.config._attn_implementation == "sdpa" and output_attentions:
66+            transformers.models.idefics.modeling_idefics.logger.warning_once(
67+                "`torch.nn.functional.scaled_dot_product_attention` does not support "
68+                "`output_attentions=True`. Falling back to "
69+                "eager attention. This warning can be removed using the argument "
70+                '`attn_implementation="eager"` when loading the model.'
71+            )
72+        else:
73+            attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
74+                self.config._attn_implementation
75+            ]
76
77     attn_output, attn_weights = attention_interface(
78         self,
79@@ -83,4 +104,9 @@
80     attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
81     attn_output = self.o_proj(attn_output)
82
83+    if output_attentions:
84+        attn_weights = None
85+
86+    if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
87+        return attn_output, attn_weights, past_key_value
88     return attn_output, attn_weights
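
Two changes above are worth isolating: the Python max over kv_seq_len and q_len becomes torch.maximum over 0-dim int64 tensors, so the comparison stays inside the graph instead of being resolved at trace time, and the return arity is gated on the installed transformers version. A small sketch of both ideas, assuming pv is packaging.version as in the patch and using illustrative function names:

import packaging.version as pv
import torch
import transformers

def symbolic_friendly_max(kv_seq_len, q_len):
    # Python max() on two symbolic sizes can force a guard at export time;
    # torch.maximum keeps the comparison as a graph operation.
    return torch.maximum(
        torch.tensor(kv_seq_len, dtype=torch.int64),
        torch.tensor(q_len, dtype=torch.int64),
    )

def version_gated_return(attn_output, attn_weights, past_key_value):
    # Older transformers releases expect a 3-tuple from attention modules,
    # newer ones a 2-tuple; the patch keeps both call sites happy.
    if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
        return attn_output, attn_weights, past_key_value
    return attn_output, attn_weights

The actual patch keeps both behaviours in a single function so the same patched class supports several transformers releases.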

auto/patch_transformers: IdeficsEmbedding.forward -> patched_IdeficsEmbedding.forward

 1--- original
 2+++ rewritten
 3@@ -1,9 +1,26 @@
 4 def forward(self, x, seq_len=None):
 5     # x: [bs, num_attention_heads, seq_len, head_size]
 6-    if seq_len > self.max_seq_len_cached:
 7-        self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 8+    # if seq_len > self.max_seq_len_cached:
 9+    #    self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
10
11-    return (
12-        self.cos_cached[:seq_len].to(dtype=x.dtype),
13-        self.sin_cached[:seq_len].to(dtype=x.dtype),
14+    def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
15+        t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
16+        # freqs = torch.einsum("i,j->ij", t, inv_freq)
17+        freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
18+        emb = torch.cat((freqs, freqs), dim=-1)
19+        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
20+
21+    def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
22+        torch._check(seq_len.item() <= cos_cached.shape[0])
23+        co = cos_cached[: seq_len.item()].detach().clone()
24+        torch._check(seq_len.item() <= sin_cached.shape[0])
25+        si = sin_cached[: seq_len.item()].detach().clone()
26+        return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
27+
28+    cos_cached, sin_cached = torch.cond(
29+        (seq_len > self.max_seq_len_cached).item(),
30+        _set_cos_sin_cache_then,
31+        _set_cos_sin_cache_else,
32+        [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
33     )
34+    return cos_cached, sin_cached

auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper
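
The first hunk above only threads the optional layer_type argument through the forward; the cos/sin computation itself is unchanged. For reference, that shared computation can be reproduced standalone with toy sizes, assuming the default rope type and attention_scaling equal to 1.0:

import torch

# Illustrative inv_freq for a head dimension of 8 and base 10000.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))  # (4,)
position_ids = torch.arange(6).unsqueeze(0)                                     # (1, 6)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, 6, 4)
emb = torch.cat((freqs, freqs), dim=-1)                              # (1, 6, 8)
cos, sin = emb.cos(), emb.sin()                                      # attention_scaling == 1.0

Only the lookup of inv_freq and attention_scaling depends on layer_type; the arithmetic stays identical to the upstream implementation.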

auto/patch_transformers: MistralRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper

auto/patch_transformers: MixtralRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper

auto/patch_transformers: Phi3RotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always be True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper

auto/patch_transformers: Phi4MultimodalRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always be True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper

auto/patch_transformers: PhiRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always be True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper

auto/patch_transformers: Qwen3MoeSparseMoeBlock.forward -> patched_Qwen3MoeSparseMoeBlock.forward

 1--- original
 2+++ rewritten
 3@@ -1,9 +1,67 @@
 4-def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 5+def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 6+    """ """
 7     batch_size, sequence_length, hidden_dim = hidden_states.shape
 8-    hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
 9-    router_logits = self.gate(hidden_states_reshaped)
10-    selected_experts, routing_weights = self.route_tokens_to_experts(
11-        hidden_states_reshaped, router_logits
12+    hidden_states = hidden_states.view(-1, hidden_dim)
13+    # router_logits: (batch * sequence_length, n_experts)
14+    router_logits = self.gate(hidden_states)
15+
16+    routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
17+    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
18+    if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
19+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
20+    # we cast back to the input dtype
21+    routing_weights = routing_weights.to(hidden_states.dtype)
22+
23+    final_hidden_states = torch.zeros(
24+        (batch_size * sequence_length, hidden_dim),
25+        dtype=hidden_states.dtype,
26+        device=hidden_states.device,
27     )
28-    final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
29-    return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
30+
31+    # One hot encode the selected experts to create an expert mask
32+    # this will be used to easily index which expert is going to be sollicitated
33+    expert_mask = torch.nn.functional.one_hot(
34+        selected_experts, num_classes=self.num_experts
35+    ).permute(2, 1, 0)
36+
37+    # Loop over all available experts in the model
38+    # and perform the computation on each expert
39+    expert_sum = expert_mask.sum(dim=(-1, -2))
40+    # expert_hit = torch.greater(expert_sum, 0).nonzero()
41+    # for expert_idx in expert_hit:
42+    for expert_idx in range(self.num_experts):
43+        # initial code has a squeeze but it is not possible to do that.
44+        # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
45+        expert_mask_idx = expert_mask[expert_idx]
46+        final_hidden_states = torch.cond(
47+            (expert_sum[expert_idx] > 0).item(),
48+            lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
49+                final_hidden_states,
50+                expert_mask,
51+                hidden_states,
52+                routing_weights,
53+                expert_idx=_i,
54+            ),
55+            lambda final_hidden_states, *args: final_hidden_states.clone(),
56+            [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
57+        )
58+
59+        # if expert_sum[expert_idx] > 0:
60+        #    idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
61+
62+        # Index the correct hidden states and compute the expert hidden state for
63+        # the current expert. We need to make sure to multiply the output hidden
64+        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
65+        #    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
66+        #    current_hidden_states = (
67+        #        expert_layer(current_state) * routing_weights[top_x, idx, None]
68+        #    )
69+
70+        # However `index_add_` only support torch tensors for indexing so we'll use
71+        # the `top_x` tensor here.
72+        #    final_hidden_states.index_add_(
73+        #        0, top_x, current_hidden_states.to(hidden_states.dtype)
74+        #    )
75+
76+    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
77+    return final_hidden_states, router_logits

auto/patch_transformers: ‘Qwen3MoeSparseMoeBlock_forward_expert_loop’ -> patched_Qwen3MoeSparseMoeBlock._forward_expert_loop

 1def _forward_expert_loop(
 2    self,
 3    final_hidden_states,
 4    expert_mask_idx,
 5    hidden_states,
 6    routing_weights,
 7    expert_idx: int,
 8):
 9    # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
10    idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
11    hidden_dim = hidden_states.shape[-1]
12    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
13    expert_current_state = self.experts[expert_idx](current_state)
14    current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
15    return final_hidden_states.index_add(0, top_x, current_hidden_states.to(hidden_states.dtype))
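
In the patched mixture-of-experts block, the loop runs over the static range(self.num_experts) so it unrolls at export time, and each expert's contribution is guarded by torch.cond on (expert_sum[expert_idx] > 0).item(). Because torch.cond branches must not mutate their inputs, _forward_expert_loop accumulates with the out-of-place index_add and the caller rebinds final_hidden_states to the returned tensor. Below is a reduced sketch of that guarded, functional accumulation; the names add_if_hit, acc, rows and values are made up for the example and do not exist in the library.

import torch

def add_if_hit(acc, hit_count, rows, values):
    # Apply the update only when the expert received tokens. Both branches
    # are functional: index_add returns a new tensor, the other branch clones.
    return torch.cond(
        hit_count > 0,
        lambda a, r, v: a.index_add(0, r, v),
        lambda a, r, v: a.clone(),
        [acc, rows, values],
    )

acc = torch.zeros(4, 3)
print(add_if_hit(acc, torch.tensor(2), torch.tensor([0, 2]), torch.ones(2, 3)))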

auto/patch_transformers: SamMaskDecoder.forward -> patched_SamMaskDecoder.forward

 1--- original
 2+++ rewritten
 3@@ -5,6 +5,7 @@
 4     sparse_prompt_embeddings: torch.Tensor,
 5     dense_prompt_embeddings: torch.Tensor,
 6     multimask_output: bool,
 7+    output_attentions: Optional[bool] = None,
 8     attention_similarity: Optional[torch.Tensor] = None,
 9     target_embedding: Optional[torch.Tensor] = None,
10 ) -> tuple[torch.Tensor, torch.Tensor]:
11@@ -22,19 +23,31 @@
12             the embeddings of the mask inputs
13         multimask_output (bool):
14             Whether to return multiple masks or a single mask.
15+        output_attentions (bool, *optional*):
16+            Whether or not to return the attentions tensors of all attention layers.
17     """
18     batch_size, num_channels, height, width = image_embeddings.shape
19-    point_batch_size = (
20-        sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
21-    )
22+    point_batch_size = sparse_prompt_embeddings.shape[1]
23     # Concatenate output tokens
24     output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
25     output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
26
27-    if sparse_prompt_embeddings is not None:
28-        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
29-    else:
30-        tokens = output_tokens
31+    # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
32+    # torch.any is needed to avoid data-dependent control flow
33+    # with sparse_prompt_embeddings.sum().item() != 0
34+    def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
35+        return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
36+
37+    def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
38+        return output_tokens.clone()
39+
40+    tokens = torch.cond(
41+        torch.any(sparse_prompt_embeddings != 0),
42+        sparse_prompt_embeddings_is_not_empty,
43+        sparse_prompt_embeddings_is_empty,
44+        [output_tokens, sparse_prompt_embeddings],
45+    )
46+
47     point_embeddings = tokens.to(self.iou_token.weight.dtype)
48
49     # Expand per-image data in batch direction to be per-point
50@@ -45,15 +58,21 @@
51     )
52
53     # Run the transformer, image_positional_embedding are consumed
54-    point_embedding, image_embeddings = self.transformer(
55+    torch._check(point_embeddings.shape[0] != 0)
56+    torch._check(point_embeddings.shape[1] != 0)
57+    torch._check(point_embeddings.shape[2] != 0)
58+    torch._check(point_embeddings.shape[3] != 0)
59+    embeddings_attentions = self.transformer(
60         point_embeddings=point_embeddings,
61         image_embeddings=image_embeddings,
62         image_positional_embeddings=image_positional_embeddings,
63         attention_similarity=attention_similarity,
64         target_embedding=target_embedding,
65+        output_attentions=output_attentions,
66     )
67-    iou_token_out = point_embedding[:, :, 0, :]
68-    mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
69+    point_embedding, image_embeddings = embeddings_attentions[:2]
70+    iou_token_out = torch.select(point_embedding, dim=2, index=0)
71+    mask_tokens_out = torch.narrow(point_embedding, dim=2, start=1, length=self.num_mask_tokens)
72
73     # Upscale mask embeddings and predict masks using the mask tokens
74     image_embeddings = image_embeddings.transpose(2, 3).reshape(
75@@ -88,4 +107,15 @@
76         mask_slice = slice(0, 1)
77     masks = masks[:, :, mask_slice, :, :]
78     iou_pred = iou_pred[:, :, mask_slice]
79-    return masks, iou_pred
80+
81+    outputs = (masks, iou_pred)
82+
83+    if len(embeddings_attentions) == 2:
84+        # transformers==4.54
85+        return outputs
86+
87+    if output_attentions and len(embeddings_attentions) > 2:
88+        outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
89+    else:
90+        outputs = outputs + (None,)  # noqa: RUF005
91+    return outputs
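
The SamMaskDecoder patch follows the same idea: the content-dependent test on sparse_prompt_embeddings becomes a tensor predicate, torch.any(sparse_prompt_embeddings != 0), consumed by torch.cond, and torch._check records the assumption that the point_embeddings dimensions are not zero so the exporter does not have to guard on empty shapes. A small sketch of a tensor-valued predicate combined with such a check follows; maybe_concat and its arguments are invented for the example.

import torch

def maybe_concat(base, extra):
    # The emptiness test stays inside the graph as a tensor predicate
    # instead of a Python branch on extra's content.
    torch._check(base.shape[0] != 0)  # assumption recorded for the exporter
    return torch.cond(
        torch.any(extra != 0),
        lambda b, e: torch.cat((b, e), dim=1),
        lambda b, e: b.clone(),
        [base, extra],
    )

print(maybe_concat(torch.ones(2, 3), torch.zeros(2, 2)).shape)
print(maybe_concat(torch.ones(2, 3), torch.rand(2, 2)).shape)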

auto/patch_transformers: SmolLM3RotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always be True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper
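
The essential change in this patch is that the data-dependent tests on seq_len (longrope and dynamic RoPE updates) are replaced by torch.cond, so both candidate frequency buffers stay in the exported graph and no guard is created on a value unknown at export time. A minimal sketch of that pattern, with toy tensors standing in for the buffers (torch.cond requires both branches to return tensors with matching shape and dtype):

import torch

# toy stand-ins for the two candidate inverse-frequency buffers
long_inv_freq = torch.rand(16)
original_inv_freq = torch.rand(16)
seq_len = torch.tensor(4096)
original_max_position_embeddings = 2048

# a plain `if seq_len > ...` would force torch.export to pick one branch at
# trace time; torch.cond keeps both branches and defers the choice to runtime
inv_freq = torch.cond(
    (seq_len > original_max_position_embeddings).item(),
    lambda long_f, orig_f: long_f.clone(),
    lambda long_f, orig_f: orig_f.clone(),
    [long_inv_freq, original_inv_freq],
)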

auto/patch_transformers: VisionAttention.forward -> patched_VisionAttention.forward

  1--- original
  2+++ rewritten
  3@@ -3,69 +3,55 @@
  4     hidden_states: torch.Tensor,
  5     cu_seqlens: torch.Tensor,
  6     rotary_pos_emb: Optional[torch.Tensor] = None,
  7-    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  8-    **kwargs,
  9+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 10 ) -> torch.Tensor:
 11     seq_length = hidden_states.shape[0]
 12-    query_states, key_states, value_states = (
 13+    q, k, v = (
 14         self.qkv(hidden_states)
 15         .reshape(seq_length, 3, self.num_heads, -1)
 16         .permute(1, 0, 2, 3)
 17         .unbind(0)
 18     )
 19-    cos, sin = position_embeddings
 20-    query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 21+    if position_embeddings is None:
 22+        transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
 23+            "The attention layers in this model are transitioning from "
 24+            " computing the RoPE embeddings internally "
 25+            "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
 26+            "to using externally computed "
 27+            "`position_embeddings` (Tuple of tensors, containing cos and sin)."
 28+            " In v4.54 `rotary_pos_emb` will be "
 29+            "removed and `position_embeddings` will be mandatory."
 30+        )
 31+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
 32+        cos = emb.cos()
 33+        sin = emb.sin()
 34+    else:
 35+        cos, sin = position_embeddings
 36+    q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
 37+        q, k, cos, sin
 38+    )
 39
 40-    query_states = query_states.transpose(0, 1).unsqueeze(0)
 41-    key_states = key_states.transpose(0, 1).unsqueeze(0)
 42-    value_states = value_states.transpose(0, 1).unsqueeze(0)
 43+    attention_mask = torch.full(
 44+        [1, seq_length, seq_length],
 45+        torch.finfo(q.dtype).min,
 46+        device=q.device,
 47+        dtype=q.dtype,
 48+    )
 49+    # for i in range(1, len(cu_seqlens)):
 50+    #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
 51+    #                         cu_seqlens[i - 1] : cu_seqlens[i]] = 0
 52+    attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
 53
 54-    attention_interface: Callable = eager_attention_forward
 55-    if self.config._attn_implementation != "eager":
 56-        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 57-
 58-    if self.config._attn_implementation == "flash_attention_2":
 59-        # Flash Attention 2: Use cu_seqlens for variable length attention
 60-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
 61-        attn_output, _ = attention_interface(
 62-            self,
 63-            query_states,
 64-            key_states,
 65-            value_states,
 66-            attention_mask=None,
 67-            scaling=self.scaling,
 68-            dropout=0.0 if not self.training else self.attention_dropout,
 69-            cu_seq_lens_q=cu_seqlens,
 70-            cu_seq_lens_k=cu_seqlens,
 71-            max_length_q=max_seqlen,
 72-            max_length_k=max_seqlen,
 73-            is_causal=False,
 74-            **kwargs,
 75-        )
 76-    else:
 77-        # Other implementations: Process each chunk separately
 78-        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
 79-        splits = [
 80-            torch.split(tensor, lengths.tolist(), dim=2)
 81-            for tensor in (query_states, key_states, value_states)
 82-        ]
 83-
 84-        attn_outputs = [
 85-            attention_interface(
 86-                self,
 87-                q,
 88-                k,
 89-                v,
 90-                attention_mask=None,
 91-                scaling=self.scaling,
 92-                dropout=0.0 if not self.training else self.attention_dropout,
 93-                is_causal=False,
 94-                **kwargs,
 95-            )[0]
 96-            for q, k, v in zip(*splits)
 97-        ]
 98-        attn_output = torch.cat(attn_outputs, dim=1)
 99-
100-    attn_output = attn_output.reshape(seq_length, -1).contiguous()
101+    q = q.transpose(0, 1)
102+    k = k.transpose(0, 1)
103+    v = v.transpose(0, 1)
104+    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
105+    attn_weights = attn_weights + attention_mask
106+    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
107+        q.dtype
108+    )
109+    attn_output = torch.matmul(attn_weights, v)
110+    attn_output = attn_output.transpose(0, 1)
111+    attn_output = attn_output.reshape(seq_length, -1)
112     attn_output = self.proj(attn_output)
113     return attn_output
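
Instead of splitting query/key/value per chunk as the original forward does, the patched VisionAttention builds one block-diagonal additive mask over the whole sequence and runs plain eager attention; rewrite_loop_for_square_mask produces that mask without the Python loop shown in the comment, which is hard to export when cu_seqlens is dynamic. A toy reproduction of the mask, written with the loop the helper avoids:

import torch

# toy chunk boundaries; in the real model cu_seqlens comes from the image grid
cu_seqlens = torch.tensor([0, 3, 7, 10])
seq_length = int(cu_seqlens[-1])
dtype = torch.float32

# everything starts at -inf, then each diagonal block is reset to 0 so tokens
# attend only within their own chunk
attention_mask = torch.full(
    [1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype
)
for i in range(1, len(cu_seqlens)):
    attention_mask[
        ..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]
    ] = 0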

auto/patch_transformers: eager_attention_forward -> patched_model_bart_eager_attention_forward

 1--- original
 2+++ rewritten
 3@@ -1,27 +1,23 @@
 4-def eager_attention_forward(
 5-    module: nn.Module,
 6+def patched_model_bart_eager_attention_forward(
 7+    module: torch.nn.Module,
 8     query: torch.Tensor,
 9     key: torch.Tensor,
10     value: torch.Tensor,
11     attention_mask: Optional[torch.Tensor],
12     scaling: Optional[float] = None,
13     dropout: float = 0.0,
14-    **kwargs: Unpack[TransformersKwargs],
15+    head_mask: Optional[torch.Tensor] = None,
16+    **kwargs,
17 ):
18-    if scaling is None:
19-        scaling = query.size(-1) ** -0.5
20-
21-    # Take the dot product between "query" and "key" to get the raw attention scores.
22-    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24-    if attention_mask is not None:
25-        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26-        attn_weights = attn_weights + attention_mask
27-
28-    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31-    attn_output = torch.matmul(attn_weights, value)
32-    attn_output = attn_output.transpose(1, 2).contiguous()
33-
34-    return attn_output, attn_weights
35+    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
36+    return common_eager_attention_forward(
37+        module,
38+        query,
39+        key,
40+        value,
41+        attention_mask=attention_mask,
42+        scaling=scaling,
43+        dropout=dropout,
44+        head_mask=head_mask,
45+        **kwargs,
46+    )

auto/patch_transformers: eager_attention_forward -> patched_modeling_marian_eager_attention_forward

 1--- original
 2+++ rewritten
 3@@ -1,27 +1,23 @@
 4-def eager_attention_forward(
 5-    module: nn.Module,
 6+def patched_modeling_marian_eager_attention_forward(
 7+    module: torch.nn.Module,
 8     query: torch.Tensor,
 9     key: torch.Tensor,
10     value: torch.Tensor,
11     attention_mask: Optional[torch.Tensor],
12     scaling: Optional[float] = None,
13     dropout: float = 0.0,
14-    **kwargs: Unpack[TransformersKwargs],
15+    head_mask: Optional[torch.Tensor] = None,
16+    **kwargs,
17 ):
18-    if scaling is None:
19-        scaling = query.size(-1) ** -0.5
20-
21-    # Take the dot product between "query" and "key" to get the raw attention scores.
22-    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24-    if attention_mask is not None:
25-        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26-        attn_weights = attn_weights + attention_mask
27-
28-    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31-    attn_output = torch.matmul(attn_weights, value)
32-    attn_output = attn_output.transpose(1, 2).contiguous()
33-
34-    return attn_output, attn_weights
35+    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
36+    return common_eager_attention_forward(
37+        module,
38+        query,
39+        key,
40+        value,
41+        attention_mask=attention_mask,
42+        scaling=scaling,
43+        dropout=dropout,
44+        head_mask=head_mask,
45+        **kwargs,
46+    )
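
The Bart and Marian patches above are identical in spirit: they drop the Unpack[TransformersKwargs] annotation, accept head_mask explicitly, and delegate to the shared common_eager_attention_forward. The sketch below is a plausible reconstruction of what such a shared eager attention has to compute, derived from the deleted bodies above plus the usual per-head mask multiplication; the helper name and exact head_mask handling are assumptions, not the library's code:

import torch


def sketch_common_eager_attention_forward(
    module,
    query,
    key,
    value,
    attention_mask=None,
    scaling=None,
    dropout=0.0,
    head_mask=None,
    **kwargs,
):
    # reconstruction for illustration only (hypothetical helper name)
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        # one scalar per attention head, broadcast over batch and positions
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_weights = torch.nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights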