Patches List#

torch#

<<<

import textwrap
from yobx.torch import apply_patches_for_model
from yobx.torch.in_torch.patches import PATCHES

# Emit a reST fragment describing every torch patch:
#   1. a bullet list of references (one per patch),
#   2. then, for each patch, a labelled section titled with the patched
#      function's qualified name and containing its diff as a literal block.
with apply_patches_for_model(patch_torch=True):
    toc_lines = []
    section_lines = []
    for index, entry in enumerate(PATCHES, start=1):
        anchor = f"patch-torch-{index}"
        title = entry.patch.__qualname__
        toc_lines.append(f"* :ref:`{anchor}`")
        section_lines += [
            "",
            f".. _{anchor}:",
            "",
            title,
            # reST section underline must be at least as long as the title.
            "-" * len(title),
            "",
            "::",
            "",
            # Indent the diff so reST treats it as the literal block body.
            textwrap.indent(entry.format_diff(), "    "),
            "",
        ]
    print("\n".join(toc_lines + [""] + section_lines))

>>>

patched_DynamicDimConstraintPrinter._print_Symbol#

torch: DynamicDimConstraintPrinter._print_Symbol -> patched_DynamicDimConstraintPrinter._print_Symbol
--- original
+++ rewritten
@@ -1,6 +1,5 @@
 def _print_Symbol(self, expr: sympy.Symbol) -> str:
-    if not isinstance(expr, sympy.Symbol):
-        raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
-    if not self.symbol_to_source.get(expr):
-        raise AssertionError(f"Unknown symbol {expr} created by constraints solver")
-    return self.symbol_to_source[expr][0].name
+    assert isinstance(expr, sympy.Symbol), str(type(expr))
+    if self.symbol_to_source.get(expr):  # type: ignore
+        return self.symbol_to_source[expr][0].name  # type: ignore
+    return str(expr)

patched_infer_size#

torch: infer_size -> patched_infer_size
--- original
+++ rewritten
@@ -1,10 +1,13 @@
-def infer_size(a: Sequence[IntLikeType], b: Sequence[IntLikeType]) -> tuple[IntLikeType, ...]:
-    from torch.fx.experimental.symbolic_shapes import guard_or_false
-
+def patched_infer_size(a, b):
+    """
+    Patches ``torch._subclasses.fake_impls.infer_size``.
+    This patch is needed to export
+    :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`.
+    """
     dimsA = len(a)
     dimsB = len(b)
     ndim = max(dimsA, dimsB)
-    expandedSizes: list[IntLikeType] = [0] * ndim
+    expandedSizes = [0] * ndim
     for i in range(ndim - 1, -1, -1):
         offset = ndim - 1 - i
         dimA = dimsA - 1 - offset
@@ -23,11 +26,21 @@
         # expression of an or statement as-is, without bool()'ing it; if this
         # were not the case, we'd need to write this using torch.sym_or() or
         # something like that).
-        torch._check(
-            guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
-            lambda: f"The size of tensor a ({sizeA}) "
-            f"must match the size of tensor b ({sizeB}) "
-            f"at non-singleton dimension {i})",
-        )
-        expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
+        try:
+            b1 = fx_symbolic_shapes.guard_or_false(sizeA == 1)
+        except fx_symbolic_shapes.GuardOnDataDependentSymNode:
+            b1 = False
+        try:
+            b2 = fx_symbolic_shapes.guard_or_false(sizeB == 1)
+        except fx_symbolic_shapes.GuardOnDataDependentSymNode:
+            b2 = False
+        try:
+            b3 = fx_symbolic_shapes.guard_or_false(sizeA == sizeB)
+        except fx_symbolic_shapes.GuardOnDataDependentSymNode:
+            b3 = False
+        if b1 or b2 or b3:
+            expandedSizes[i] = sizeB if fx_symbolic_shapes.guard_or_false(sizeA == 1) else sizeA
+        else:
+            # PATCHED: generic case, the dimension is known, no need to assert
+            expandedSizes[i] = torch.sym_max(sizeA, sizeB)  # type: ignore
     return tuple(expandedSizes)

patched__broadcast_shapes#

torch: _broadcast_shapes -> patched__broadcast_shapes
--- original
+++ rewritten
@@ -1,12 +1,11 @@
-def _broadcast_shapes(*_shapes):
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        guarding_hint_or_throw,
-        has_guarding_hint,
-        is_nested_int,
-    )
-
-    backed_so = torch.fx.experimental._config.backed_size_oblivious
+def patched__broadcast_shapes(*_shapes):
+    """
+    Patches ``torch._refs._broadcast_shapes``.
+    This patch is needed to export
+    :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`.
+    """
+    from functools import reduce
+    from torch._prims_common import IntLike

     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
@@ -19,63 +18,36 @@
     for shape in shapes:
         if not isinstance(shape, Sequence):
             raise RuntimeError(
-                "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
+                "Input shapes should be of type ints, a tuple of ints, "
+                "or a list of ints, got ",
                 shape,
             )

     # Computes common shape
-    common_shape: list[int | torch.SymInt] = [
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
-    for arg_idx, shape in enumerate(shapes):
+    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
+    for _arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
-            # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
-            if is_nested_int(shape[idx]):
+            if fx_symbolic_shapes.is_nested_int(shape[idx]):
                 # Broadcasting is allowed for (j0, 1) or (j0, j0);
                 # not (j0, j1), (j0, 5), etc.
-                if is_nested_int(common_shape[idx]) and guard_or_false(
-                    shape[idx] == common_shape[idx]
-                ):
+                if fx_symbolic_shapes.is_nested_int(
+                    common_shape[idx]
+                ) and fx_symbolic_shapes.guard_or_false(shape[idx] == common_shape[idx]):
                     continue
             else:
-                # When backed size oblivious is used, we specialize for broadcasting
-                # if its the only way to compile the example input.
-                # i.e: s0:1, s1:1 ==>
-                #           assert s0==s1, no specialization on ==1 or !=1.
-                #            The non-broadcast path is picked
-                #      s0:1, s1:4 ==>
-                #           specialize(s0) to be 1.
-                #      s0:4, s1:1 ==>
-                #           specialize(s1) to be 1.
-                if (
-                    backed_so
-                    and has_guarding_hint(shape[idx])
-                    and has_guarding_hint(common_shape[idx])
-                ):
-                    a = guarding_hint_or_throw(shape[idx])
-                    b = guarding_hint_or_throw(common_shape[idx])
-                    if a == 1 and b != 1:
-                        torch._check(shape[idx] == 1)
-                    if b == 1 and a != 1:
-                        torch._check(common_shape[idx] == 1)
-                if guard_or_false(shape[idx] == common_shape[idx]):
+                if fx_symbolic_shapes.guard_or_false(shape[idx] == common_shape[idx]):
                     continue
-
-            if guard_or_false(common_shape[idx] == 1):
+            # PATCHED: two cases, if == for sure, no broadcast,
+            # otherwise maybe broadcast with max(dimensions)
+            if fx_symbolic_shapes.guard_or_false(common_shape[idx] != 1):
+                pass
+            elif fx_symbolic_shapes.guard_or_false(
+                common_shape[idx] == 1
+            ) or fx_symbolic_shapes.guard_or_false(shape[idx] != 1):
                 if shape[idx] < 0:
                     raise ValueError("Attempting to broadcast a dimension with negative length!")
-                common_shape[idx] = shape[idx]
-
-            if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
-                # broadcast case .
-                continue
+                common_shape[idx] = shape[idx]  # type: ignore
             else:
-                # If broadcasting is undecided we pick non-broadcast path and add runtime assertion.
-                torch._check(
-                    common_shape[idx] == shape[idx],
-                    lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
-                    f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
-                    f"should be broadcastable to {common_shape}",
-                )
+                common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])  # type: ignore

     return common_shape

patched__get_range_constraints#

torch: _get_range_constraints -> patched__get_range_constraints
--- original
+++ rewritten
@@ -1,42 +1,36 @@
-def _get_range_constraints(
+def patched__get_range_constraints(
     mod: torch.nn.Module,
-    export_artifact: ExportArtifact,
+    export_artifact: torch.export._trace.ExportArtifact,
     args,
     kwargs,
     dynamic_shapes,
 ):
+    """
+    Patches ``torch.export._trace._get_range_constraints``.
+    See PR `#174593 <https://github.com/pytorch/pytorch/pull/174593>`_.
+    """
     gm: torch.fx.GraphModule = export_artifact.aten.gm
-    export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
+    export_graph_signature: torch.export.graph_signature.ExportGraphSignature = (
+        export_artifact.aten.sig
+    )
     fake_mode: FakeTensorMode = export_artifact.fake_mode
     num_lifted = next(
         (
             i
             for i, s in enumerate(export_graph_signature.input_specs)
-            if s.kind == InputKind.USER_INPUT
+            if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
         ),
         len(export_graph_signature.input_specs),
     )
-    combined_args = _combine_args(mod, args, kwargs)

-    # This is because we trace based on the kwargs passed in from user
+    # preserve_order=True:
+    # this is because we trace based on the kwargs passed in from user
     # not based on the signature. I feel it would be better to just enforce
     # one ordering at the start of tracing to avoid confusions, but that is
     # bigger refactor, so do this to unblock for now.
-    combined_args_traced_order = {}
-    for arg in combined_args:
-        if arg not in kwargs:
-            combined_args_traced_order[arg] = combined_args[arg]
+    combined_args = _combine_args(mod, args, kwargs, preserve_order=True)

-    for key in kwargs:
-        combined_args_traced_order[key] = kwargs[key]
-
-    combined_args = combined_args_traced_order
-
-    range_constraints = make_constraints(
-        fake_mode,
-        gm,
-        combined_args,
-        dynamic_shapes,
-        num_lifted,
+    range_constraints = torch._export.non_strict_utils.make_constraints(
+        fake_mode, gm, combined_args, dynamic_shapes, num_lifted
     )
     return range_constraints

patched__maybe_broadcast#

torch: _maybe_broadcast -> patched__maybe_broadcast
--- original
+++ rewritten
@@ -1,15 +1,18 @@
-def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+    """
+    Patches ``torch._refs._maybe_broadcast``.
+    This patch is needed to export
+    :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`.
+    """
+    from torch._prims_common import ShapeType, TensorLike, Number
+
     # Computes common shape
-    common_shape = _broadcast_shapes(
+    common_shape = patched__broadcast_shapes(
         *(t.shape if isinstance(t, TensorLike) else None for t in args)
     )

     def should_expand(a: ShapeType, b: ShapeType) -> bool:
-        from torch.fx.experimental.symbolic_shapes import (
-            guard_or_false,
-            sym_and,
-            sym_or,
-        )
+        from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and, sym_or

         if len(a) != len(b):
             return True
@@ -29,10 +32,15 @@
                 return True

             # u0==u1 assume the same, no broadcasting!
-            torch._check(
-                x == y,
-                lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
-            )
+            # PATCHED: avoid errors
+            return True  # guard_or_true(x != y)
+            # torch._check(
+            #    x == y,
+            #    lambda x=x, y=y: (
+            #        f"sizes assumed to be the same due to unbacked "
+            #        f"broadcasting semantics x={x!r}, y={y!r}"
+            #    ),
+            # )

         return False

@@ -42,14 +50,14 @@
         elif isinstance(x, Number):
             return x
         elif isinstance(x, TensorLike):
-            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
+            if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):  # type: ignore
                 return x

-            if should_expand(x.shape, common_shape):
-                return x.expand(common_shape)
+            if should_expand(x.shape, common_shape):  # type: ignore
+                return x.expand(common_shape)  # type: ignore

             return x
         else:
-            raise RuntimeError("Unexpected type when broadcasting: " + str(type(x)) + "!")
+            raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")

     return tuple(__maybe_broadcast(x, common_shape) for x in args)

transformers#

<<<

import textwrap
import warnings
from yobx.torch import apply_patches_for_model
from yobx.torch.in_torch.patches import PATCHES

# Emit a reST fragment describing every transformers patch: a bullet list of
# references first, then one labelled section per patch (titled with the
# patched function's qualified name) holding its diff as a literal block.
# ``details`` is the value yielded by ``apply_patches_for_model``; each item
# exposes ``.patch`` and ``.format_diff()``.
#
# While these patches are applied, something gets imported that emits
# "builtin type SwigPyPacked/SwigPyObject has no __module__ attribute"
# DeprecationWarnings on stderr. Those warnings were captured together with
# the printed text and broke the generated bullet list, so they are silenced,
# but only inside this snippet.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    with apply_patches_for_model(patch_transformers=True) as details:
        link = []
        rows = []
        for i, patch in enumerate(details, start=1):
            name = f"patch-transformers-{i}"
            qualname = patch.patch.__qualname__
            link.append(f"* :ref:`{name}`")
            rows.extend(
                [
                    "",
                    f".. _{name}:",
                    "",
                    qualname,
                    # reST section underline must be at least as long as the title.
                    "-" * len(qualname),
                    "",
                    "::",
                    "",
                    # Indent the diff so reST treats it as the literal block body.
                    textwrap.indent(patch.format_diff(), "    "),
                    "",
                ]
            )
        print("\n".join([*link, "", *rows]))

>>>

<frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute
<frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

* common_RotaryEmbedding.forward

common_RotaryEmbedding.forward#

transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward
--- original
+++ rewritten
@@ -1,18 +1,26 @@
-@torch.no_grad()
-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-def forward(self, x, position_ids):
+@patched_dynamic_rope_update
+def forward(self, x, position_ids, layer_type=None):
+    if layer_type is not None:
+        # transformers>=5.0
+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+    else:
+        # transformers<5.0
+        inv_freq = self.inv_freq
+        attention_scaling = self.attention_scaling
+
     inv_freq_expanded = (
-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
     )
     position_ids_expanded = position_ids[:, None, :].float()

     device_type = (
         x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
     )
-    with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos() * self.attention_scaling
-        sin = emb.sin() * self.attention_scaling
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling

     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

--- original
+++ rewritten
@@ -1,103 +1,193 @@
-def dynamic_rope_update(rope_forward):
-    """
-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

-    Args:
-        rope_forward (Callable):
-            The forward pass of the RoPE implementation.
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.

-    Returns:
-        The decorated forward pass.
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings =
+                        self.config.original_max_position_embeddings
+                else:
+                    original_max_position_embeddings =
+                        self.config.max_position_embeddings
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if seq_len < self.original_max_seq_len and
+                        self.max_seq_len_cached > self.original_max_seq_len:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """

     def longrope_frequency_update(self, position_ids, device, layer_type=None):
-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings

         if layer_type is None:
-            rope_type = self.rope_type
-            original_inv_freq = self.original_inv_freq
-            prefix = ""
-            original_max_position_embeddings = self.config.rope_parameters[
-                "original_max_position_embeddings"
-            ]
-        else:
-            rope_type = self.rope_type[layer_type]
-            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
-            prefix = f"{layer_type}_"
-            original_max_position_embeddings = self.config.rope_parameters[layer_type][
-                "original_max_position_embeddings"
-            ]
-
-        if seq_len > original_max_position_embeddings:
-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-                long_inv_freq, _ = rope_init_fn(
-                    self.config,
-                    device,
-                    seq_len=original_max_position_embeddings + 1,
-                    layer_type=layer_type,
-                )
-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
-        else:
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-
-    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if layer_type is None:
-            rope_type = self.rope_type
-            max_seq_len_cached = self.max_seq_len_cached
+            # rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
         else:
-            rope_type = self.rope_type[layer_type]
-            max_seq_len_cached = getattr(
-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
-            )
+            # rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"

-        if seq_len > max_seq_len_cached:  # growth
-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-            inv_freq, self.attention_scaling = rope_init_fn(
-                self.config,
-                device,
-                seq_len=seq_len,
-                layer_type=layer_type,
-            )
-            # TODO joao: may break with compilation
-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)

-        if (
-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
-        ):  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
+        # PATCHED: uses torch.cond instead of a test
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)
+        # if seq_len > original_max_position_embeddings:
+        #    self.inv_freq = self.long_inv_freq
+        # else:
+        #    self.inv_freq = self.original_inv_freq
+
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows.
+        # The test should always True.
+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
+        #        self.config, device, seq_len=seq_len
+        #    )
+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #    self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #    cond,
+        #    (lambda x, y: x.clone()),
+        #    (lambda x, y: y.clone()),
+        #    [attention_scaling, self.attention_scaling],
+        # )
+
+        seq_len = torch.max(position_ids) + 1
+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
+        # Second test to translate.
+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+        # But in that case the following condition is a way to restore the original cache.
+
+        # if (
+        #    seq_len < self.original_max_seq_len
+        #    and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #    self.original_inv_freq = self.original_inv_freq.to(device)
+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #    self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)

     @wraps(rope_forward)
     def wrapper(self, x, position_ids, layer_type=None):
-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
-        if "dynamic" in rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
-        elif rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
-        return rope_forward(self, x, position_ids, **kwargs)
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        return rope_forward(self, x, position_ids, layer_type=layer_type)

     return wrapper