Patches Diff
Patches are not always needed to export an LLM.
Most of the time, only serialization functions are needed to export
an LLM with a cache (DynamicCache, …).
Function register_additional_serialization_functions
is enough in many cases.
import torch
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

with register_additional_serialization_functions(patch_transformers=True):
    ep = torch.export.export(...)
Function torch_export_patches
helps fix export issues for many models.
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(...)
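The snippet below is a more complete sketch of the same pattern. It is only an
illustration under assumptions not stated above: a small hypothetical checkpoint
(sshleifer/tiny-gpt2) and plain input_ids / attention_mask inputs without a cache,
with dynamic dimensions declared through torch.export.Dim.

import torch
from transformers import AutoModelForCausalLM
from onnx_diagnostic.torch_export_patches import torch_export_patches

# hypothetical tiny checkpoint, used only to keep the sketch fast
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
model.eval()
input_ids = torch.randint(0, model.config.vocab_size, (2, 8))
attention_mask = torch.ones_like(input_ids)

# dynamic batch and sequence dimensions
batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model,
        (),
        kwargs={"input_ids": input_ids, "attention_mask": attention_mask},
        dynamic_shapes={
            "input_ids": {0: batch, 1: seq},
            "attention_mask": {0: batch, 1: seq},
        },
    )
print(ep.graph)

Wrapping the call in torch_export_patches is what lets transformers code paths
that would otherwise fail under torch.export be traced.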
Class PatchDetails
shows how to retrieve the list of patches involved for a specific model.
These patches belong to the following list, which depends on the transformers and
pytorch versions.
<<<
import torch
import transformers
print(torch.__version__, transformers.__version__)
>>>
2.10.0.dev20251208+cu130 5.0.0.dev0
These two versions lead to the following list of patches.
<<<
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_export_patches import torch_export_patches

details = PatchDetails()
with torch_export_patches(
    patch_transformers=True,
    patch_torch=True,
    patch_diffusers=True,
    patch_details=details,
):
    pass

for patch in details.patched:
    if patch.function_to_patch == patch.patch:
        continue
    rst = patch.format_diff(format="rst")
    print()
    print()
    print(rst)
    print()
    print()
>>>
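The rendered diffs below are the output of the snippet above. For a short summary
instead of full diffs, the same details.patched list can be reduced to the
original/patched pairs shown in the headings that follow. This is a minimal sketch
relying only on the attributes already used above; the getattr calls guard against
the entries being plain strings rather than callables, and the exact list depends
on the installed torch and transformers versions.

for patch in details.patched:
    if patch.function_to_patch == patch.patch:
        # nothing was actually rewritten for this entry
        continue
    original = getattr(patch.function_to_patch, "__qualname__", patch.function_to_patch)
    replaced = getattr(patch.patch, "__qualname__", patch.patch)
    print(f"{original} -> {replaced}")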
sympy: ‘sympy.core.numbers.IntegerConstant.name’ -> _patch_sympy.<locals>.<lambda>
1sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
torch: infer_size -> patched_infer_size
1--- original
2+++ rewritten
3@@ -1,4 +1,5 @@
4-def infer_size(a, b):
5+def patched_infer_size(a, b):
6+ """Patches ``torch._subclasses.fake_impls.infer_size``."""
7 from torch.fx.experimental.symbolic_shapes import guard_or_false
8
9 dimsA = len(a)
10@@ -23,11 +24,21 @@
11 # expression of an or statement as-is, without bool()'ing it; if this
12 # were not the case, we'd need to write this using torch.sym_or() or
13 # something like that).
14- torch._check(
15- guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
16- lambda: f"The size of tensor a ({sizeA}) "
17- f"must match the size of tensor b ({sizeB}) "
18- f"at non-singleton dimension {i})",
19- )
20- expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
21+ try:
22+ b1 = guard_or_false(sizeA == 1)
23+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
24+ b1 = False
25+ try:
26+ b2 = guard_or_false(sizeB == 1)
27+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
28+ b2 = False
29+ try:
30+ b3 = guard_or_false(sizeA == sizeB)
31+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
32+ b3 = False
33+ if b1 or b2 or b3:
34+ expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
35+ else:
36+ # PATCHED: generic case, the dimension is known, no need to assert
37+ expandedSizes[i] = torch.sym_max(sizeA, sizeB)
38 return tuple(expandedSizes)
torch: _broadcast_shapes -> patched__broadcast_shapes
1--- original
2+++ rewritten
3@@ -1,11 +1,11 @@
4-def _broadcast_shapes(*_shapes):
5+def patched__broadcast_shapes(*_shapes):
6+ """Patches ``torch._refs._broadcast_shapes``."""
7+ from functools import reduce
8+ from torch._prims_common import IntLike
9 from torch.fx.experimental.symbolic_shapes import (
10 guard_or_false,
11 is_nested_int,
12- size_hint,
13 )
14-
15- backed_so = torch.fx.experimental._config.backed_size_oblivious
16
17 shapes = tuple(
18 (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
19@@ -18,17 +18,15 @@
20 for shape in shapes:
21 if not isinstance(shape, Sequence):
22 raise RuntimeError(
23- "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
24+ "Input shapes should be of type ints, a tuple of ints, "
25+ "or a list of ints, got ",
26 shape,
27 )
28
29 # Computes common shape
30- common_shape: list[Union[int, torch.SymInt]] = [
31- 1,
32- ] * reduce(max, (len(shape) for shape in shapes))
33- for arg_idx, shape in enumerate(shapes):
34+ common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
35+ for _arg_idx, shape in enumerate(shapes):
36 for idx in range(-1, -1 - len(shape), -1):
37- # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
38 if is_nested_int(shape[idx]):
39 # Broadcasting is allowed for (j0, 1) or (j0, j0);
40 # not (j0, j1), (j0, 5), etc.
41@@ -37,40 +35,17 @@
42 ):
43 continue
44 else:
45- # When backed size oblivious is used, we specialize for broadcasting
46- # if its the only way to compile the example input.
47- # i.e: s0:1, s1:1 ==>
48- # assert s0==s1, no specialization on ==1 or !=1.
49- # The non-broadcast path is picked
50- # s0:1, s1:4 ==>
51- # specialize(s0) to be 1.
52- # s0:4, s1:1 ==>
53- # specialize(s1) to be 1.
54- if backed_so:
55- a = size_hint(shape[idx], allow_none=True)
56- b = size_hint(common_shape[idx], allow_none=True)
57- if a == 1 and b != 1:
58- torch._check(shape[idx] == 1)
59- if b == 1 and a != 1:
60- torch._check(common_shape[idx] == 1)
61 if guard_or_false(shape[idx] == common_shape[idx]):
62 continue
63-
64- if guard_or_false(common_shape[idx] == 1):
65+ # PATCHED: two cases, if == for sure, no broadcast,
66+ # otherwise maybe broadcast with max(dimensions)
67+ if guard_or_false(common_shape[idx] != 1):
68+ pass
69+ elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
70 if shape[idx] < 0:
71 raise ValueError("Attempting to broadcast a dimension with negative length!")
72 common_shape[idx] = shape[idx]
73-
74- if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
75- # broadcast case .
76- continue
77 else:
78- # If broadcasting is undecided we pick non-broadcast path and add runtime assertion.
79- torch._check(
80- common_shape[idx] == shape[idx],
81- lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
82- f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
83- f"should be broadcastable to {common_shape}",
84- )
85+ common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
86
87 return common_shape
torch: _constrain_user_specified_dimhint_range -> patched__constrain_user_specified_dimhint_range
1--- original
2+++ rewritten
3@@ -1,28 +1,31 @@
4-def _constrain_user_specified_dimhint_range(
5+def patched__constrain_user_specified_dimhint_range(
6 symint: torch.SymInt,
7 hint: int,
8- dim: _DimHint,
9+ dim: "_DimHint", # noqa: F821
10 range_constraints,
11 shape_env,
12- keypath: KeyPath,
13+ keypath: "KeyPath", # noqa: F821
14 i: Optional[int] = None,
15 ) -> Optional[str]:
16+ """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
17+ from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
18+
19 trace_vr = (
20 range_constraints[symint.node.expr]
21 if not is_int(symint)
22 else ValueRanges(int(symint), int(symint))
23 )
24-
25 # warn on 0/1 specialization for Dim.AUTO; not an actual error
26- if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
27- pathstr = f"inputs{pytree.keystr(keypath)}"
28- if i is not None:
29- pathstr += f".shape[{i}]"
30- msg = (
31- f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
32- + f"with a sample input with hint = {hint}."
33- )
34- log.warning(msg)
35+ # PATCHED: remove logging
36+ # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
37+ # pathstr = f"inputs{pytree.keystr(keypath)}"
38+ # if i is not None:
39+ # pathstr += f".shape[{i}]"
40+ # msg = (
41+ # f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
42+ # f"with a sample input with hint = {hint}."
43+ # )
44+ # log.warning(msg)
45
46 try:
47 user_vr = ValueRanges(
48@@ -38,32 +41,40 @@
49
50 # check for Dim.DYNAMIC specializations; special case error message on 0/1
51 if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
52- path = f"inputs{pytree.keystr(keypath)}"
53+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
54 if i is not None:
55 path += f".shape[{i}]"
56 if (
57 trace_vr.is_singleton()
58 and hint in (0, 1)
59- and not torch.fx.experimental._config.backed_size_oblivious
60+ # PATCHED: line removed
61+ # and not torch.fx.experimental._config.backed_size_oblivious
62 ):
63- msg = (
64- f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
65- f"but export 0/1 specialized due to hint of {hint} for dimension {path}."
66- )
67+ return None
68+ # PATCHED: line removed
69+ # msg = (
70+ # f"- Received user-specified dim hint "
71+ # f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
72+ # f"but export 0/1 specialized due to hint of "
73+ # f"{hint} for dimension {path}."
74+ # )
75 else:
76 msg = (
77- f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
78- f"but tracing inferred a static shape of {out_vr.lower} for dimension {path}."
79+ f"- Received user-specified dim hint "
80+ f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
81+ f"but tracing inferred a static shape of "
82+ f"{out_vr.lower} for dimension {path}."
83 )
84 return msg
85
86 except torch.utils._sympy.value_ranges.ValueRangeError:
87- path = f"inputs{pytree.keystr(keypath)}"
88+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
89 if i is not None:
90 path += f".shape[{i}]"
91 msg = (
92 f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
93- f"conflicting with the inferred min/max range of [{trace_vr.lower}, {trace_vr.upper}], "
94+ f"conflicting with the inferred min/max range of "
95+ f"[{trace_vr.lower}, {trace_vr.upper}], "
96 f"for {path}."
97 )
98 return msg
torch: _broadcast_in_dim_meta -> patched__broadcast_in_dim_meta
1--- original
2+++ rewritten
3@@ -1,6 +1,9 @@
4-def _broadcast_in_dim_meta(
5- a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
6+def patched__broadcast_in_dim_meta(
7+ a: torch._prims_common.TensorLikeType,
8+ shape: torch._prims_common.ShapeType,
9+ broadcast_dimensions: Sequence[int],
10 ):
11+ """Patches ``torch._prims._broadcast_in_dim_meta``."""
12 from torch.fx.experimental.symbolic_shapes import (
13 guard_or_false,
14 guard_or_true,
15@@ -8,7 +11,7 @@
16 )
17
18 # Type checks
19- assert isinstance(a, TensorLike)
20+ assert isinstance(a, torch._prims_common.TensorLike)
21 assert isinstance(shape, Sequence)
22 assert isinstance(broadcast_dimensions, Sequence)
23
24@@ -22,7 +25,7 @@
25 # (no relative reordering of dims) of integers and
26 # each dimension must be within the new shape
27 def _greater_than_reduce(acc, x):
28- assert isinstance(x, Dim)
29+ assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
30 assert x > acc
31 assert x < len(shape)
32
33@@ -34,7 +37,9 @@
34 for idx, new_idx in enumerate(broadcast_dimensions):
35 torch._check(
36 sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
37- lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
38+ lambda idx=idx, new_idx=new_idx: (
39+ f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
40+ ),
41 )
42
43 new_strides = []
44@@ -48,10 +53,26 @@
45 new_strides.append(a.stride()[original_idx])
46 else:
47 new_strides.append(0)
48+ # PATCHED: disabled this check
49+ elif guard_or_false(a.shape[original_idx] != 1):
50+ new_strides.append(a.stride()[original_idx])
51 else:
52+ # This checks generates the following issue:
53+ # non-broadcasting semantics require s3 == Max(s10, s3), False,
54+ # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
55+ # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
56+ # original_idx=1
57 torch._check(
58 a.shape[original_idx] == shape[idx],
59- lambda: f"non-broadcasting semantics require {a.shape[original_idx]} == {shape[idx]}",
60+ lambda idx=idx, original_idx=original_idx: (
61+ f"non-broadcasting semantics require "
62+ f"{a.shape[original_idx]} == {shape[idx]}, "
63+ f"{guard_or_false(a.shape[idx] != 1)}, "
64+ f"guard_or_false(a.shape[idx]==1)="
65+ f"{guard_or_false(a.shape[idx] == 1)}, "
66+ f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
67+ f"shape={shape}, original_idx={original_idx}"
68+ ),
69 )
70 new_strides.append(a.stride()[original_idx])
71 original_idx = original_idx + 1
torch: _maybe_broadcast -> patched__maybe_broadcast
1--- original
2+++ rewritten
3@@ -1,6 +1,9 @@
4-def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
5+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
6+ """Patches ``torch._refs._maybe_broadcast``."""
7+ from torch._prims_common import ShapeType, TensorLike, Number
8+
9 # Computes common shape
10- common_shape = _broadcast_shapes(
11+ common_shape = patched__broadcast_shapes(
12 *(t.shape if isinstance(t, TensorLike) else None for t in args)
13 )
14
15@@ -29,10 +32,15 @@
16 return True
17
18 # u0==u1 assume the same, no broadcasting!
19- torch._check(
20- x == y,
21- lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
22- )
23+ # PATCHED: avoid errors
24+ return True # guard_or_true(x != y)
25+ # torch._check(
26+ # x == y,
27+ # lambda x=x, y=y: (
28+ # f"sizes assumed to be the same due to unbacked "
29+ # f"broadcasting semantics x={x!r}, y={y!r}"
30+ # ),
31+ # )
32
33 return False
34
35@@ -42,7 +50,7 @@
36 elif isinstance(x, Number):
37 return x
38 elif isinstance(x, TensorLike):
39- if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
40+ if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
41 return x
42
43 if should_expand(x.shape, common_shape):
44@@ -50,6 +58,6 @@
45
46 return x
47 else:
48- raise RuntimeError("Unexpected type when broadcasting: " + str(type(x)) + "!")
49+ raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
50
51 return tuple(__maybe_broadcast(x, common_shape) for x in args)
torch: ShapeEnv._evaluate_expr -> patched_ShapeEnv._evaluate_expr
1--- original
2+++ rewritten
3@@ -1,14 +1,24 @@
4 def _evaluate_expr(
5 self,
6- orig_expr: sympy.Basic,
7+ orig_expr: "sympy.Basic", # noqa: F821
8 hint: Optional[Union[bool, int, float]] = None,
9 fx_node: Optional[torch.fx.Node] = None,
10 size_oblivious: bool = False,
11 fallback_value: Optional[bool] = None,
12 *,
13 forcing_spec: bool = False,
14-) -> sympy.Basic:
15+) -> "sympy.Basic": # noqa: F821
16 # TODO: split conjunctions and evaluate them separately
17+ import sympy
18+ from torch.fx.experimental import _config as config
19+ from torch.fx.experimental.symbolic_shapes import (
20+ SympyBoolean,
21+ log,
22+ SymT,
23+ symbol_is_type,
24+ )
25+ from torch._guards import ShapeGuard
26+
27 if isinstance(
28 orig_expr,
29 (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
30@@ -118,7 +128,8 @@
31 self._log_suppressed_dde(orig_expr, fallback_value)
32 return fallback_value
33
34- # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type.
35+ # oblivious_var_to_val will be defined iff we have sizes
36+ # with DimDynamic.OBLIVIOUS_SIZE type.
37 # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
38 if (
39 self.oblivious_var_to_val
40@@ -136,15 +147,17 @@
41 log.info(
42 "oblivious_size %s -> %s (passed counterfactual)",
43 orig_expr,
44+ # pyrefly: ignore # unbound-name
45 correct_hint,
46 )
47-
48+ # pyrefly: ignore # unbound-name
49 concrete_val = correct_hint
50 # NB: do NOT transmute into runtime assert
51 ok = True
52
53 # unbacked_var_to_val is not None iff propagate_real_tensors is on.
54- # if propagate_real_tensors is on, we check the example values to generate (unsound_result)
55+ # if propagate_real_tensors is on, we check the example values
56+ # to generate (unsound_result)
57 # and if they pass we add a runtime assertions and continue.
58 if (
59 not ok
60@@ -155,25 +168,29 @@
61 )
62 ).free_symbols
63 ):
64+ # pyrefly: ignore # unbound-name
65 self._log_real_tensor_propagation(orig_expr, unsound_result)
66 transmute_into_runtime_assert = True
67-
68+ # pyrefly: ignore # unbound-name
69 concrete_val = unsound_result
70 ok = True
71
72- # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion
73+ # Check if this is coming from a python assert statement,
74+ # if so, convert it to a runtime assertion
75 # instead of failing.
76 if not ok and self.trace_asserts and self._is_python_assert():
77 concrete_val = sympy.true
78 transmute_into_runtime_assert = True
79 ok = True
80
81- if not ok:
82- raise self._make_data_dependent_error(
83- expr.xreplace(self.var_to_val),
84- expr,
85- expr_sym_node_id=self._expr_sym_node_id,
86- )
87+ # PATCHED: ok -> True
88+ ok = True
89+ # if not ok:
90+ # raise self._make_data_dependent_error(
91+ # expr.xreplace(self.var_to_val),
92+ # expr,
93+ # expr_sym_node_id=self._expr_sym_node_id,
94+ # )
95 else:
96 expr = new_expr
patch_transformers: dynamic_rope_update -> patched_dynamic_rope_update
1--- original
2+++ rewritten
3@@ -1,99 +1,193 @@
4-def dynamic_rope_update(rope_forward):
5- """
6- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
7- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
8+def patched_dynamic_rope_update(rope_forward):
9+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
10
11- Args:
12- rope_forward (Callable):
13- The forward pass of the RoPE implementation.
14+ ``rope_type`` is determined in the constructor of class
15+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
16
17- Returns:
18- The decorated forward pass.
19+ .. code-block:: python
20+
21+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
22+ self.rope_type = config.rope_scaling.get(
23+ "rope_type", config.rope_scaling.get("type"))
24+ else:
25+ self.rope_type = "default"
26+
27+ The original code of the patched function:
28+
29+ .. code-block:: python
30+
31+ def dynamic_rope_update(rope_forward):
32+ def longrope_frequency_update(self, position_ids, device):
33+ seq_len = torch.max(position_ids) + 1
34+ if hasattr(self.config, "original_max_position_embeddings"):
35+ original_max_position_embeddings =
36+ self.config.original_max_position_embeddings
37+ else:
38+ original_max_position_embeddings =
39+ self.config.max_position_embeddings
40+ if seq_len > original_max_position_embeddings:
41+ if not hasattr(self, "long_inv_freq"):
42+ self.long_inv_freq, _ = self.rope_init_fn(
43+ self.config, device, seq_len=original_max_position_embeddings + 1
44+ )
45+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
46+ else:
47+ self.original_inv_freq = self.original_inv_freq.to(device)
48+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
49+
50+ def dynamic_frequency_update(self, position_ids, device):
51+ seq_len = torch.max(position_ids) + 1
52+ if seq_len > self.max_seq_len_cached: # growth
53+ inv_freq, self.attention_scaling = self.rope_init_fn(
54+ self.config, device, seq_len=seq_len)
55+ self.register_buffer("inv_freq", inv_freq, persistent=False)
56+ self.max_seq_len_cached = seq_len
57+
58+ if seq_len < self.original_max_seq_len and
59+ self.max_seq_len_cached > self.original_max_seq_len:
60+ self.original_inv_freq = self.original_inv_freq.to(device)
61+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
62+ self.max_seq_len_cached = self.original_max_seq_len
63+
64+ @wraps(rope_forward)
65+ def wrapper(self, x, position_ids):
66+ if "dynamic" in self.rope_type:
67+ dynamic_frequency_update(self, position_ids, device=x.device)
68+ elif self.rope_type == "longrope":
69+ longrope_frequency_update(self, position_ids, device=x.device)
70+ return rope_forward(self, x, position_ids)
71+
72+ return wrapper
73+
74 """
75
76 def longrope_frequency_update(self, position_ids, device, layer_type=None):
77- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
78+ # It is no use to patch the function after the model is created
79+ # as rope_init_fn is an attribute set to one function when the model
80+ # is created and when no patch is applied yet.
81+ # So we select the patched version here.
82+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
83 seq_len = torch.max(position_ids) + 1
84- original_max_position_embeddings = getattr(
85- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
86- )
87+ if hasattr(self.config, "original_max_position_embeddings"):
88+ original_max_position_embeddings = self.config.original_max_position_embeddings
89+ else:
90+ original_max_position_embeddings = self.config.max_position_embeddings
91+
92 if layer_type is None:
93- rope_type = self.rope_type
94+ # rope_type = self.rope_type
95 original_inv_freq = self.original_inv_freq
96 prefix = ""
97 else:
98- rope_type = self.rope_type[layer_type]
99+ # rope_type = self.rope_type[layer_type]
100 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
101 prefix = f"{layer_type}_"
102
103- if seq_len > original_max_position_embeddings:
104- if not hasattr(self, f"{layer_type}_long_inv_freq"):
105- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
106- long_inv_freq, _ = rope_init_fn(
107- self.config,
108- device,
109- seq_len=original_max_position_embeddings + 1,
110- layer_type=layer_type,
111- )
112- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
113- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
114- else:
115- # This .to() is needed if the model has been moved to a device after being initialized (because
116- # the buffer is automatically moved, but not the original copy)
117- original_inv_freq = original_inv_freq.to(device)
118- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
119- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
120+ # At export time, seq_len is unknown.
121+ long_inv_freq, _ = rope_init_fn(
122+ self.config, device, seq_len=original_max_position_embeddings + 1
123+ )
124+ original_inv_freq = self.original_inv_freq.to(device)
125+
126+ # PATCHED: uses torch.cond instead of a test
127+ cond = (seq_len > original_max_position_embeddings).item()
128+ inv_freq = torch.cond(
129+ cond,
130+ (lambda x, y: x.clone()),
131+ (lambda x, y: y.clone()),
132+ [long_inv_freq, original_inv_freq],
133+ )
134+ setattr(self, f"{prefix}inv_freq", inv_freq)
135+ # if seq_len > original_max_position_embeddings:
136+ # self.inv_freq = self.long_inv_freq
137+ # else:
138+ # self.inv_freq = self.original_inv_freq
139
140 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
141- """
142- dynamic RoPE layers should recompute `inv_freq` in the following situations:
143- 1 - growing beyond the cached sequence length (allow scaling)
144- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
145- """
146+ # constructor:
147+ # - self.max_seq_len_cached = config.max_position_embeddings
148+ # - self.original_max_seq_len = config.max_position_embeddings
149+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
150+
151+ # It is no use to patch the function after the model is created
152+ # as rope_init_fn is an attribute set to one function when the model
153+ # is created and when no patch is applied yet.
154+ # So we select the patched version here.
155+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
156+
157+ # This behaviour is difficult to translate.
158+ # The sequence always grows.
159+ # The test should always True.
160+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
161+ #
162+ # if seq_len > self.max_seq_len_cached: # growth
163+ # inv_freq, self.attention_scaling = self.rope_init_fn(
164+ # self.config, device, seq_len=seq_len
165+ # )
166+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
167+ # self.max_seq_len_cached = seq_len
168+ #
169+ # So we should not need what follows.
170+ #
171+ # cond = (seq_len > self.max_seq_len_cached).item()
172+ # self.attention_scaling = torch.cond(
173+ # cond,
174+ # (lambda x, y: x.clone()),
175+ # (lambda x, y: y.clone()),
176+ # [attention_scaling, self.attention_scaling],
177+ # )
178+
179 seq_len = torch.max(position_ids) + 1
180+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
181+
182 if layer_type is None:
183- rope_type = self.rope_type
184- max_seq_len_cached = self.max_seq_len_cached
185+ # rope_type = self.rope_type
186+ # max_seq_len_cached = self.max_seq_len_cached
187 original_inv_freq = self.original_inv_freq
188 prefix = ""
189 else:
190- rope_type = self.rope_type[layer_type]
191- max_seq_len_cached = getattr(
192- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
193- )
194+ # rope_type = self.rope_type[layer_type]
195+ # max_seq_len_cached = getattr(
196+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
197+ # )
198 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
199 prefix = f"{layer_type}_"
200
201- if seq_len > max_seq_len_cached: # growth
202- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
203- inv_freq, self.attention_scaling = rope_init_fn(
204- self.config,
205- device,
206- seq_len=seq_len,
207- layer_type=layer_type,
208- )
209- # TODO joao: may break with compilation
210- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
211- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
212+ # Second test to translate.
213+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
214+ # But in that case the following condition is a way to restore the original cache.
215
216- if (
217- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
218- ): # reset
219- # This .to() is needed if the model has been moved to a device after being initialized (because
220- # the buffer is automatically moved, but not the original copy)
221- original_inv_freq = original_inv_freq.to(device)
222- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
223- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
224- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
225+ # if (
226+ # seq_len < self.original_max_seq_len
227+ # and self.max_seq_len_cached > self.original_max_seq_len
228+ # ):
229+ # self.original_inv_freq = self.original_inv_freq.to(device)
230+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
231+ # self.max_seq_len_cached = self.original_max_seq_len
232+
233+ original_inv_freq = self.original_inv_freq.to(device)
234+ cond = (seq_len >= self.original_max_seq_len).item()
235+ # PATCHED: uses torch.cond instead of a test
236+ inv_freq = torch.cond(
237+ cond,
238+ (lambda x, y: x.clone()),
239+ (lambda x, y: y.clone()),
240+ [long_inv_freq, original_inv_freq],
241+ )
242+ setattr(self, f"{prefix}inv_freq", inv_freq)
243
244 @wraps(rope_forward)
245 def wrapper(self, x, position_ids, layer_type=None):
246- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
247- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
248- if "dynamic" in rope_type:
249- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
250- elif rope_type == "longrope":
251- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
252- return rope_forward(self, x, position_ids, **kwargs)
253+ if layer_type is None:
254+ if "dynamic" in self.rope_type:
255+ dynamic_frequency_update(self, position_ids, device=x.device)
256+ elif self.rope_type == "longrope":
257+ longrope_frequency_update(self, position_ids, device=x.device)
258+ return rope_forward(self, x, position_ids)
259+
260+ if "dynamic" in self.rope_type:
261+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
262+ elif self.rope_type == "longrope":
263+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
264+ return rope_forward(self, x, position_ids, layer_type=layer_type)
265
266 return wrapper
transformers: eager_mask -> patched_eager_mask
1--- original
2+++ rewritten
3@@ -1,4 +1,4 @@
4-def eager_mask(
5+def patched_eager_mask(
6 batch_size: int,
7 cache_position: torch.Tensor,
8 kv_length: int,
9@@ -6,41 +6,14 @@
10 mask_function: Callable = causal_mask_function,
11 attention_mask: Optional[torch.Tensor] = None,
12 dtype: torch.dtype = torch.float32,
13- allow_is_bidirectional_skip: bool = False,
14- use_vmap: bool = False,
15 **kwargs,
16 ) -> torch.Tensor:
17- """
18- Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
19- the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
20- it should not.
21-
22- Args:
23- batch_size (`int`):
24- The batch size of the input sequence.
25- cache_position (`torch.Tensor`):
26- A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
27- kv_length (`int`):
28- The size that the key and value states will have during the attention computation.
29- kv_offset (`int`, optional):
30- An optional offset to indicate at which first position the key and values states will refer to.
31- mask_function (`Callable`):
32- The mask factory function describing the mask pattern.
33- attention_mask (`torch.Tensor`, optional):
34- The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
35- dtype (`torch.dtype`, optional):
36- The dtype to use for the mask. By default, `torch.float32`.
37- allow_is_bidirectional_skip (`bool`, optional):
38- Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
39- i.e. full attention without any padding. Default to `False`.
40- use_vmap (`bool`, optional):
41- Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
42- index-based (for the cost of speed performance). By default `False`.
43- """
44+ """manual patch for function ``transformers.masking_utils.eager_mask``."""
45 # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
46 _ = kwargs.pop("allow_is_causal_skip", None)
47- _ = kwargs.pop("allow_torch_fix", None)
48- mask = sdpa_mask(
49+ _ = kwargs.pop("allow_is_bidirectional_skip", None)
50+ # PATCHED: this line called the patched version of sdpa_mask
51+ mask = patched_sdpa_mask_recent_torch(
52 batch_size=batch_size,
53 cache_position=cache_position,
54 kv_length=kv_length,
55@@ -48,14 +21,15 @@
56 mask_function=mask_function,
57 attention_mask=attention_mask,
58 allow_is_causal_skip=False,
59- allow_is_bidirectional_skip=allow_is_bidirectional_skip,
60+ allow_is_bidirectional_skip=False,
61 allow_torch_fix=False,
62- use_vmap=use_vmap,
63 **kwargs,
64 )
65- # only bidirectional masks can be skipped, otherwise we convert bool -> float
66- if mask is not None:
67- min_dtype = torch.finfo(dtype).min
68- # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
69- mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
70+ min_dtype = torch.finfo(dtype).min
71+ # PATCHED: the following line
72+ # we need 0s where the tokens should be taken into account,
73+ # and -inf otherwise (mask is already of boolean type)
74+ # mask =
75+ # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
76+ mask = (~mask).to(dtype) * min_dtype
77 return mask
transformers: sdpa_attention_forward -> patched_sdpa_attention_forward
1--- original
2+++ rewritten
3@@ -1,65 +1,142 @@
4-def sdpa_attention_forward(
5+def patched_sdpa_attention_forward(
6 module: torch.nn.Module,
7 query: torch.Tensor,
8 key: torch.Tensor,
9 value: torch.Tensor,
10- attention_mask: torch.Tensor | None,
11+ attention_mask: Optional[torch.Tensor],
12 dropout: float = 0.0,
13- scaling: float | None = None,
14- is_causal: bool | None = None,
15+ scaling: Optional[float] = None,
16+ is_causal: Optional[bool] = None,
17 **kwargs,
18 ) -> tuple[torch.Tensor, None]:
19- if kwargs.get("output_attentions", False):
20- logger.warning_once(
21- "`sdpa` attention does not support `output_attentions=True`."
22- " Please set your attention to `eager` if you want any of these features."
23- )
24+ """
25+ manual patch for function
26+ ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
27+ """
28+ assert not kwargs.get("output_attentions", False), (
29+ "`sdpa` attention does not support `output_attentions=True`."
30+ " Please set your attention to `eager` if you want any of these features."
31+ )
32+ torch._check(
33+ query.shape[0] == key.shape[0] or query.shape[0] == 1,
34+ lambda: (
35+ f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
36+ f"value: {value.shape}"
37+ ),
38+ )
39+ torch._check(
40+ key.shape[0] == value.shape[0] or key.shape[0] == 1,
41+ lambda: (
42+ f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
43+ f"value: {value.shape}"
44+ ),
45+ )
46+
47 sdpa_kwargs = {}
48 if hasattr(module, "num_key_value_groups"):
49- if not use_gqa_in_sdpa(attention_mask, key):
50- key = repeat_kv(key, module.num_key_value_groups)
51- value = repeat_kv(value, module.num_key_value_groups)
52+ if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
53+ key = transformers.integrations.sdpa_attention.repeat_kv(
54+ key, module.num_key_value_groups
55+ )
56+ value = transformers.integrations.sdpa_attention.repeat_kv(
57+ value, module.num_key_value_groups
58+ )
59 else:
60 sdpa_kwargs = {"enable_gqa": True}
61
62- # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
63- is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
64+ if attention_mask is not None and attention_mask.ndim == 4:
65+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
66
67- # SDPA's Flash Attention (and cuDNN) kernels rely on the `is_causal` flag. However, there are certain conditions:
68- # - Not in decoding phase (otherwise we want full attention on the single query token)
69- # - Attention mask is not to be provided (even if it is a causal pattern)
70- # - Internally, we marked this as compatible with causal, i.e. it is a decoder attention type
71- #
72- # Quirks on the conditionals:
73- # - We avoid inline passing this to the SDPA function directly to support both torch.compile's dynamic shapes and
74- # full graph options. Otherwise, dynamic shapes are prevented from compiling.
75- # - It is important to check first for the shape, otherwise compile will fail with
76- # `argument 'is_causal' must be bool, not SymBool`.
77- is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
78+ torch._check(
79+ attention_mask is None or attention_mask.shape[3] == key.shape[2],
80+ lambda: "Attention mask shape incompatible with key shape.",
81+ )
82
83- # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
84- # We convert it to a bool for the SDPA kernel that only accepts bools.
85- if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
86- is_causal = is_causal.item()
87+ if patch_sdpa_is_causal:
88+ # transformers>=4.55
89+ is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
90
91- # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator,
92- # and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
93- # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
94- if _is_torch_npu_available:
95- if attention_mask is not None and attention_mask.dtype != torch.bool:
96- # Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
97- attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
98+ # PATCHED: remove the test query.shape[2] > 1
99+ # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
100+ # and we split the test to keep the minimum in torch.cond
101+ is_causal = attention_mask is None and is_causal
102
103- attn_output = torch.nn.functional.scaled_dot_product_attention(
104- query,
105- key,
106- value,
107- attn_mask=attention_mask,
108- dropout_p=dropout,
109- scale=scaling,
110- is_causal=is_causal,
111- **sdpa_kwargs,
112+ if not is_causal:
113+ torch._check(query.shape[0] > 0)
114+ torch._check(query.shape[1] > 0)
115+ torch._check(query.shape[2] > 0)
116+ torch._check(query.shape[3] > 0)
117+ torch._check(key.shape[0] > 0)
118+ torch._check(key.shape[1] > 0)
119+ torch._check(key.shape[2] > 0)
120+ torch._check(key.shape[3] > 0)
121+ torch._check(value.shape[0] > 0)
122+ torch._check(value.shape[1] > 0)
123+ torch._check(value.shape[2] > 0)
124+ torch._check(value.shape[3] > 0)
125+
126+ return (
127+ torch.nn.functional.scaled_dot_product_attention(
128+ query,
129+ key,
130+ value,
131+ attn_mask=attention_mask,
132+ dropout_p=dropout,
133+ scale=scaling,
134+ is_causal=is_causal,
135+ **sdpa_kwargs,
136+ )
137+ .transpose(1, 2)
138+ .contiguous(),
139+ None,
140+ )
141+ else:
142+ # transformers<4.55
143+ if is_causal is None and attention_mask is not None:
144+ is_causal = False
145+ if is_causal is not None:
146+ return (
147+ torch.nn.functional.scaled_dot_product_attention(
148+ query,
149+ key,
150+ value,
151+ attn_mask=attention_mask,
152+ dropout_p=dropout,
153+ scale=scaling,
154+ is_causal=is_causal,
155+ **sdpa_kwargs,
156+ )
157+ .transpose(1, 2)
158+ .contiguous(),
159+ None,
160+ )
161+
162+ # To avoid the following errors:
163+ # is_causal=query.shape[2] > 1
164+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
165+ # is_causal=torch.tensor(query.shape[2] > 1)
166+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
167+ attn_output = torch.cond(
168+ query.shape[2] > 1, # distinction between prefill and decoding steps
169+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
170+ query,
171+ key,
172+ value,
173+ dropout_p=dropout,
174+ scale=scaling,
175+ is_causal=True,
176+ **sdpa_kwargs,
177+ ).contiguous(),
178+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
179+ query,
180+ key,
181+ value,
182+ dropout_p=dropout,
183+ scale=scaling,
184+ is_causal=False,
185+ **sdpa_kwargs,
186+ ).contiguous(),
187+ [query, key, value],
188 )
189 attn_output = attn_output.transpose(1, 2).contiguous()
190-
191 return attn_output, None
auto/patch_transformers: DynamicLayer.lazy_initialization -> patched_DynamicLayer.lazy_initialization
1--- original
2+++ rewritten
3@@ -1,5 +1,9 @@
4 def lazy_initialization(self, key_states: torch.Tensor):
5 self.dtype, self.device = key_states.dtype, key_states.device
6- self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
7- self.values = torch.tensor([], dtype=self.dtype, device=self.device)
8- self.is_initialized = True
9+ new_shape = list(key_states.shape)
10+ new_shape[-2] = 0
11+ # PATCHED: used a tensor with an empty shape and not en empty list to initialize
12+ self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
13+ self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
14+ if patch_is_initialized:
15+ self.is_initialized = True
auto/patch_transformers: Gemma2RotaryEmbedding.forward -> common_RotaryEmbedding.forward
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+ # The test should always True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
auto/patch_transformers: Gemma3Model.get_placeholder_mask -> patched_Gemma3Model.get_placeholder_mask
1--- original
2+++ rewritten
3@@ -4,14 +4,12 @@
4 inputs_embeds: torch.FloatTensor,
5 image_features: torch.FloatTensor,
6 ):
7- """
8- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
9- equal to the length of multimodal features. If the lengths are different, an error is raised.
10- """
11 if input_ids is None:
12 special_image_mask = inputs_embeds == self.get_input_embeddings()(
13 torch.tensor(
14- self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
15+ self.config.image_token_id,
16+ dtype=torch.long,
17+ device=inputs_embeds.device,
18 )
19 )
20 special_image_mask = special_image_mask.all(-1)
21@@ -23,8 +21,14 @@
22 special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
23 )
24 n_image_features = image_features.shape[0] * image_features.shape[1]
25- if inputs_embeds[special_image_mask].numel() != image_features.numel():
26- raise ValueError(
27- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
28- )
29+ # PATCHED: torch._check
30+ # if inputs_embeds[special_image_mask].numel() != image_features.numel():
31+ # raise ValueError( ... )
32+ torch._check(
33+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
34+ lambda: (
35+ f"Image features and image tokens do not match: tokens: "
36+ f"{n_image_tokens}, features {n_image_features}"
37+ ),
38+ )
39 return special_image_mask
auto/patch_transformers: Gemma3RotaryEmbedding.forward -> common_RotaryEmbedding.forward
1--- original
2+++ rewritten
3@@ -1,8 +1,13 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6+@patched_dynamic_rope_update
7 def forward(self, x, position_ids, layer_type=None):
8- inv_freq = getattr(self, f"{layer_type}_inv_freq")
9- attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
10+ if layer_type is not None:
11+ # transformers>=5.0
12+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
13+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
14+ else:
15+ # transformers<5.0
16+ inv_freq = self.inv_freq
17+ attention_scaling = self.attention_scaling
18
19 inv_freq_expanded = (
20 inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21@@ -12,7 +17,7 @@
22 device_type = (
23 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
24 )
25- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
26+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
27 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
28 emb = torch.cat((freqs, freqs), dim=-1)
29 cos = emb.cos() * attention_scaling
30
31--- original
32+++ rewritten
33@@ -1,99 +1,193 @@
34-def dynamic_rope_update(rope_forward):
35- """
36- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
37- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
38+def patched_dynamic_rope_update(rope_forward):
39+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
40
41- Args:
42- rope_forward (Callable):
43- The forward pass of the RoPE implementation.
44+ ``rope_type`` is determined in the constructor of class
45+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
46
47- Returns:
48- The decorated forward pass.
49+ .. code-block:: python
50+
51+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
52+ self.rope_type = config.rope_scaling.get(
53+ "rope_type", config.rope_scaling.get("type"))
54+ else:
55+ self.rope_type = "default"
56+
57+ The original code of the patched function:
58+
59+ .. code-block:: python
60+
61+ def dynamic_rope_update(rope_forward):
62+ def longrope_frequency_update(self, position_ids, device):
63+ seq_len = torch.max(position_ids) + 1
64+ if hasattr(self.config, "original_max_position_embeddings"):
65+ original_max_position_embeddings =
66+ self.config.original_max_position_embeddings
67+ else:
68+ original_max_position_embeddings =
69+ self.config.max_position_embeddings
70+ if seq_len > original_max_position_embeddings:
71+ if not hasattr(self, "long_inv_freq"):
72+ self.long_inv_freq, _ = self.rope_init_fn(
73+ self.config, device, seq_len=original_max_position_embeddings + 1
74+ )
75+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
76+ else:
77+ self.original_inv_freq = self.original_inv_freq.to(device)
78+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
79+
80+ def dynamic_frequency_update(self, position_ids, device):
81+ seq_len = torch.max(position_ids) + 1
82+ if seq_len > self.max_seq_len_cached: # growth
83+ inv_freq, self.attention_scaling = self.rope_init_fn(
84+ self.config, device, seq_len=seq_len)
85+ self.register_buffer("inv_freq", inv_freq, persistent=False)
86+ self.max_seq_len_cached = seq_len
87+
88+ if seq_len < self.original_max_seq_len and
89+ self.max_seq_len_cached > self.original_max_seq_len:
90+ self.original_inv_freq = self.original_inv_freq.to(device)
91+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
92+ self.max_seq_len_cached = self.original_max_seq_len
93+
94+ @wraps(rope_forward)
95+ def wrapper(self, x, position_ids):
96+ if "dynamic" in self.rope_type:
97+ dynamic_frequency_update(self, position_ids, device=x.device)
98+ elif self.rope_type == "longrope":
99+ longrope_frequency_update(self, position_ids, device=x.device)
100+ return rope_forward(self, x, position_ids)
101+
102+ return wrapper
103+
104 """
105
106 def longrope_frequency_update(self, position_ids, device, layer_type=None):
107- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
108+ # It is no use to patch the function after the model is created
109+ # as rope_init_fn is an attribute set to one function when the model
110+ # is created and when no patch is applied yet.
111+ # So we select the patched version here.
112+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
113 seq_len = torch.max(position_ids) + 1
114- original_max_position_embeddings = getattr(
115- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
116- )
117+ if hasattr(self.config, "original_max_position_embeddings"):
118+ original_max_position_embeddings = self.config.original_max_position_embeddings
119+ else:
120+ original_max_position_embeddings = self.config.max_position_embeddings
121+
122 if layer_type is None:
123- rope_type = self.rope_type
124+ # rope_type = self.rope_type
125 original_inv_freq = self.original_inv_freq
126 prefix = ""
127 else:
128- rope_type = self.rope_type[layer_type]
129+ # rope_type = self.rope_type[layer_type]
130 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
131 prefix = f"{layer_type}_"
132
133- if seq_len > original_max_position_embeddings:
134- if not hasattr(self, f"{layer_type}_long_inv_freq"):
135- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
136- long_inv_freq, _ = rope_init_fn(
137- self.config,
138- device,
139- seq_len=original_max_position_embeddings + 1,
140- layer_type=layer_type,
141- )
142- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
143- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
144- else:
145- # This .to() is needed if the model has been moved to a device after being initialized (because
146- # the buffer is automatically moved, but not the original copy)
147- original_inv_freq = original_inv_freq.to(device)
148- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
149- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
150+ # At export time, seq_len is unknown.
151+ long_inv_freq, _ = rope_init_fn(
152+ self.config, device, seq_len=original_max_position_embeddings + 1
153+ )
154+ original_inv_freq = self.original_inv_freq.to(device)
155+
156+ # PATCHED: uses torch.cond instead of a test
157+ cond = (seq_len > original_max_position_embeddings).item()
158+ inv_freq = torch.cond(
159+ cond,
160+ (lambda x, y: x.clone()),
161+ (lambda x, y: y.clone()),
162+ [long_inv_freq, original_inv_freq],
163+ )
164+ setattr(self, f"{prefix}inv_freq", inv_freq)
165+ # if seq_len > original_max_position_embeddings:
166+ # self.inv_freq = self.long_inv_freq
167+ # else:
168+ # self.inv_freq = self.original_inv_freq
169
170 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
171- """
172- dynamic RoPE layers should recompute `inv_freq` in the following situations:
173- 1 - growing beyond the cached sequence length (allow scaling)
174- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
175- """
176+ # constructor:
177+ # - self.max_seq_len_cached = config.max_position_embeddings
178+ # - self.original_max_seq_len = config.max_position_embeddings
179+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
180+
181+ # It is no use to patch the function after the model is created
182+ # as rope_init_fn is an attribute set to one function when the model
183+ # is created and when no patch is applied yet.
184+ # So we select the patched version here.
185+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
186+
187+ # This behaviour is difficult to translate.
188+ # The sequence always grows.
189+ # The test should always True.
190+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
191+ #
192+ # if seq_len > self.max_seq_len_cached: # growth
193+ # inv_freq, self.attention_scaling = self.rope_init_fn(
194+ # self.config, device, seq_len=seq_len
195+ # )
196+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
197+ # self.max_seq_len_cached = seq_len
198+ #
199+ # So we should not need what follows.
200+ #
201+ # cond = (seq_len > self.max_seq_len_cached).item()
202+ # self.attention_scaling = torch.cond(
203+ # cond,
204+ # (lambda x, y: x.clone()),
205+ # (lambda x, y: y.clone()),
206+ # [attention_scaling, self.attention_scaling],
207+ # )
208+
209 seq_len = torch.max(position_ids) + 1
210+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
211+
212 if layer_type is None:
213- rope_type = self.rope_type
214- max_seq_len_cached = self.max_seq_len_cached
215+ # rope_type = self.rope_type
216+ # max_seq_len_cached = self.max_seq_len_cached
217 original_inv_freq = self.original_inv_freq
218 prefix = ""
219 else:
220- rope_type = self.rope_type[layer_type]
221- max_seq_len_cached = getattr(
222- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
223- )
224+ # rope_type = self.rope_type[layer_type]
225+ # max_seq_len_cached = getattr(
226+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227+ # )
228 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
229 prefix = f"{layer_type}_"
230
231- if seq_len > max_seq_len_cached: # growth
232- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
233- inv_freq, self.attention_scaling = rope_init_fn(
234- self.config,
235- device,
236- seq_len=seq_len,
237- layer_type=layer_type,
238- )
239- # TODO joao: may break with compilation
240- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
241- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
242+ # Second test to translate.
243+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
244+ # But in that case the following condition is a way to restore the original cache.
245
246- if (
247- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
248- ): # reset
249- # This .to() is needed if the model has been moved to a device after being initialized (because
250- # the buffer is automatically moved, but not the original copy)
251- original_inv_freq = original_inv_freq.to(device)
252- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
253- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
254- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
255+ # if (
256+ # seq_len < self.original_max_seq_len
257+ # and self.max_seq_len_cached > self.original_max_seq_len
258+ # ):
259+ # self.original_inv_freq = self.original_inv_freq.to(device)
260+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
261+ # self.max_seq_len_cached = self.original_max_seq_len
262+
263+ original_inv_freq = self.original_inv_freq.to(device)
264+ cond = (seq_len >= self.original_max_seq_len).item()
265+ # PATCHED: uses torch.cond instead of a test
266+ inv_freq = torch.cond(
267+ cond,
268+ (lambda x, y: x.clone()),
269+ (lambda x, y: y.clone()),
270+ [long_inv_freq, original_inv_freq],
271+ )
272+ setattr(self, f"{prefix}inv_freq", inv_freq)
273
274 @wraps(rope_forward)
275 def wrapper(self, x, position_ids, layer_type=None):
276- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
277- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
278- if "dynamic" in rope_type:
279- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
280- elif rope_type == "longrope":
281- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
282- return rope_forward(self, x, position_ids, **kwargs)
283+ if layer_type is None:
284+ if "dynamic" in self.rope_type:
285+ dynamic_frequency_update(self, position_ids, device=x.device)
286+ elif self.rope_type == "longrope":
287+ longrope_frequency_update(self, position_ids, device=x.device)
288+ return rope_forward(self, x, position_ids)
289+
290+ if "dynamic" in self.rope_type:
291+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
292+ elif self.rope_type == "longrope":
293+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
294+ return rope_forward(self, x, position_ids, layer_type=layer_type)
295
296 return wrapper
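Every RotaryEmbedding patch in this listing relies on the same device: the Python test if seq_len > ... is replaced by torch.cond, so both candidate inv_freq tensors remain in the exported graph and the selection happens at run time. The sketch below shows that selection pattern in isolation, under assumed names (PickInvFreq, threshold); it is not the patched decorator.

import torch


class PickInvFreq(torch.nn.Module):
    def __init__(self, threshold: int = 16):
        super().__init__()
        # hypothetical stand-in for original_max_position_embeddings
        self.threshold = threshold

    def forward(self, position_ids, long_inv_freq, short_inv_freq):
        seq_len = torch.max(position_ids) + 1
        # .item() yields a symbolic boolean the exporter keeps as a predicate
        cond = (seq_len > self.threshold).item()
        return torch.cond(
            cond,
            lambda x, y: x.clone(),  # long sequences: recomputed frequencies
            lambda x, y: y.clone(),  # short sequences: original frequencies
            [long_inv_freq, short_inv_freq],
        )


ep = torch.export.export(
    PickInvFreq(),
    (torch.arange(8).unsqueeze(0), torch.rand(32), torch.rand(32)),
)

The .clone() calls mirror the patched code: each branch returns a fresh tensor rather than one of the operands.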
auto/patch_transformers: GemmaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+            # The test should always be True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation -> patched_GenerationMixin._cache_dependant_input_preparation¶
1--- original
2+++ rewritten
3@@ -1,25 +1,31 @@
4 def _cache_dependant_input_preparation(
5 self,
6 input_ids: torch.LongTensor,
7- inputs_embeds: torch.FloatTensor | None,
8- cache_position: torch.LongTensor | None,
9-) -> tuple[torch.FloatTensor, torch.LongTensor]:
10+ inputs_embeds: Optional[torch.FloatTensor],
11+ cache_position: Optional[torch.LongTensor],
12+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
13 """
14 Generic cache-dependent input preparation
15 The code is put in a separate function to allow granular unit testing
16 as it needs a different implementation to be exportable.
17
18- If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
19- - Exception 1: when passing input_embeds, input_ids may be missing entries
20- - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
21- - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
22- - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
23- generate the first token for each sequence. Later use the generated Input ids for continuation.
24+ If we have cache: let's slice `input_ids` through `cache_position`,
25+ to keep only the unprocessed tokens
26+ - Exception 1: when passing input_embeds,
27+ input_ids may be missing entries
28+ - Exception 2: some generation methods do special slicing of input_ids,
29+ so we don't need to do it here
30+ - Exception 3: with synced GPUs cache_position may go out of bounds,
31+ but we only want dummy token in that case.
32+ - Exception 4: If input_embeds are passed then slice it through
33+ `cache_position`, to keep only the unprocessed tokens and
34+ generate the first token for each sequence.
35+ Later use the generated Input ids for continuation.
36
37 The current implementation does not rely on ``self`` and could be
38 a class method. It is left as a standard method to be easily rewritten.
39 """
40- if is_torchdynamo_exporting():
41+ if _is_torchdynamo_exporting():
42 return self._cache_dependant_input_preparation_exporting(
43 input_ids, inputs_embeds, cache_position
44 )
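Besides the Optional/Tuple annotations, the behavioural point of this diff is the dispatch at the end: while torch.export traces the model, the method routes to the exporting-friendly implementation. The sketch below illustrates that dispatch pattern with hypothetical names (_is_exporting, prepare_inputs); the getattr fallback is an assumption for torch versions that predate torch.compiler.is_exporting.

import torch


def _is_exporting() -> bool:
    # torch.compiler.is_exporting exists in recent torch releases; assume False
    # when the attribute is missing (older versions).
    fn = getattr(torch.compiler, "is_exporting", None)
    return bool(fn()) if fn is not None else False


def prepare_inputs(input_ids, cache_position):
    if _is_exporting():
        # export path: expressed with graph-friendly operators only
        return torch.index_select(input_ids, 1, cache_position)
    # eager path: plain Python branching is fine outside export
    if input_ids.shape[1] != cache_position.shape[0]:
        return input_ids[:, cache_position]
    return input_ids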
auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation_exporting -> patched_GenerationMixin._cache_dependant_input_preparation_exporting¶
1--- original
2+++ rewritten
3@@ -1,9 +1,9 @@
4 def _cache_dependant_input_preparation_exporting(
5 self,
6 input_ids: torch.LongTensor,
7- inputs_embeds: torch.FloatTensor | None,
8- cache_position: torch.LongTensor | None,
9-) -> tuple[torch.FloatTensor, torch.LongTensor]:
10+ inputs_embeds: Optional[torch.FloatTensor],
11+ cache_position: Optional[torch.LongTensor],
12+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
13 """
14 This method implements method ``_cache_dependant_input_preparation``
15 with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
16@@ -21,7 +21,6 @@
17 # else:
18 # if input_ids.shape[1] != cache_position.shape[0]:
19 # input_ids = input_ids[:, cache_position]
20- # We need to clone the outputs to avoid aliasing.
21 def branch_1(inputs_embeds, cache_position):
22 return inputs_embeds[:, -cache_position.shape[0] :].clone()
23
24@@ -49,7 +48,7 @@
25 torch.cond(
26 input_ids.shape[1] != cache_position.shape[0],
27 branch_3,
28- (lambda input_ids, cache_position: input_ids.clone()),
29+ (lambda input_ids, cache_position: input_ids),
30 [input_ids, cache_position],
31 )
32 ),
auto/patch_transformers: IdeficsAttention.forward -> patched_IdeficsAttention.forward¶
1--- original
2+++ rewritten
3@@ -4,10 +4,12 @@
4 key_value_states: Optional[torch.Tensor] = None,
5 attention_mask: Optional[torch.Tensor] = None,
6 position_ids: Optional[torch.LongTensor] = None,
7- past_key_values: Optional[Cache] = None,
8+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
9+ output_attentions: bool = False,
10+ use_cache: bool = False,
11 cache_position: Optional[torch.LongTensor] = None,
12- **kwargs: Unpack[TransformersKwargs],
13-) -> tuple[torch.Tensor, torch.Tensor]:
14+ **kwargs,
15+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
16 # if key_value_states are provided this layer is used as a cross-attention layer
17 is_cross_attention = self.is_cross_attention or key_value_states is not None
18
19@@ -43,20 +45,27 @@
20 )
21
22 kv_seq_len = key_states.shape[-2]
23- if past_key_values is not None:
24+ if past_key_value is not None:
25 kv_seq_len += cache_position[0]
26
27 if not is_cross_attention:
28- cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
29- query_states, key_states = apply_rotary_pos_emb(
30- query_states, key_states, cos, sin, position_ids
31+ rotary_length = torch.maximum(
32+ torch.tensor(kv_seq_len, dtype=torch.int64),
33+ torch.tensor(q_len, dtype=torch.int64),
34+ )
35+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
36+ query_states, key_states = (
37+ transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
38+ query_states, key_states, cos, sin, position_ids
39+ )
40 )
41 # [bsz, nh, t, hd]
42
43- if past_key_values is not None:
44- # sin and cos are specific to RoPE models; cache_position needed for the static cache
45+ if past_key_value is not None:
46+ # sin and cos are specific to RoPE models;
47+ # cache_position needed for the static cache
48 cache_kwargs = {"cache_position": cache_position}
49- key_states, value_states = past_key_values.update(
50+ key_states, value_states = past_key_value.update(
51 key_states, value_states, self.layer_idx, cache_kwargs
52 )
53
54@@ -64,10 +73,22 @@
55 query_states = self.q_layer_norm(query_states)
56 key_states = self.k_layer_norm(key_states)
57
58- attention_interface: Callable = eager_attention_forward
59+ attention_interface: Callable = (
60+ transformers.models.idefics.modeling_idefics.eager_attention_forward
61+ )
62
63 if self.config._attn_implementation != "eager":
64- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
65+ if self.config._attn_implementation == "sdpa" and output_attentions:
66+ transformers.models.idefics.modeling_idefics.logger.warning_once(
67+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
68+ "`output_attentions=True`. Falling back to "
69+ "eager attention. This warning can be removed using the argument "
70+ '`attn_implementation="eager"` when loading the model.'
71+ )
72+ else:
73+ attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
74+ self.config._attn_implementation
75+ ]
76
77 attn_output, attn_weights = attention_interface(
78 self,
79@@ -83,4 +104,9 @@
80 attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
81 attn_output = self.o_proj(attn_output)
82
83+ if output_attentions:
84+ attn_weights = None
85+
86+ if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
87+ return attn_output, attn_weights, past_key_value
88 return attn_output, attn_weights
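Two details of this attention patch matter for export: the rotary length is computed with torch.maximum on tensors instead of Python max, so the exporter does not have to guard on which of the two symbolic sizes is larger, and the attention backend is looked up explicitly in ALL_ATTENTION_FUNCTIONS. The sketch below isolates the first detail under an assumed name (RotaryLength); it is not the patched attention.

import torch


class RotaryLength(torch.nn.Module):
    def forward(self, key_states, query_states):
        kv_seq_len = key_states.shape[-2]
        q_len = query_states.shape[-2]
        # Python max(kv_seq_len, q_len) would make the tracer pick one branch
        # and guard on it; the tensor-level maximum keeps both in the graph.
        return torch.maximum(
            torch.tensor(kv_seq_len, dtype=torch.int64),
            torch.tensor(q_len, dtype=torch.int64),
        )


kv = torch.export.Dim("kv")
q = torch.export.Dim("q")
ep = torch.export.export(
    RotaryLength(),
    (torch.rand(1, 2, 6, 8), torch.rand(1, 2, 4, 8)),
    dynamic_shapes=({2: kv}, {2: q}),
)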
auto/patch_transformers: IdeficsEmbedding.forward -> patched_IdeficsEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,9 +1,26 @@
4 def forward(self, x, seq_len=None):
5 # x: [bs, num_attention_heads, seq_len, head_size]
6- if seq_len > self.max_seq_len_cached:
7- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
8+ # if seq_len > self.max_seq_len_cached:
9+ # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
10
11- return (
12- self.cos_cached[:seq_len].to(dtype=x.dtype),
13- self.sin_cached[:seq_len].to(dtype=x.dtype),
14+ def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
15+ t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
16+ # freqs = torch.einsum("i,j->ij", t, inv_freq)
17+ freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
18+ emb = torch.cat((freqs, freqs), dim=-1)
19+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
20+
21+ def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
22+ torch._check(seq_len.item() <= cos_cached.shape[0])
23+ co = cos_cached[: seq_len.item()].detach().clone()
24+ torch._check(seq_len.item() <= sin_cached.shape[0])
25+ si = sin_cached[: seq_len.item()].detach().clone()
26+ return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
27+
28+ cos_cached, sin_cached = torch.cond(
29+ (seq_len > self.max_seq_len_cached).item(),
30+ _set_cos_sin_cache_then,
31+ _set_cos_sin_cache_else,
32+ [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
33 )
34+ return cos_cached, sin_cached
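This embedding patch combines both idioms used throughout the listing: torch.cond decides whether the cos/sin caches are recomputed or sliced, and torch._check gives the exporter the bound it needs to slice with a data-dependent length. The snippet below sketches only the slicing branch, with hypothetical names (SliceCache); it is an illustration, not the patched module.

import torch


class SliceCache(torch.nn.Module):
    def forward(self, cos_cached, seq_len):
        n = seq_len.item()  # unbacked symbolic length during export
        torch._check(n >= 0)
        # upper bound required for the exporter to accept the data-dependent slice
        torch._check(n <= cos_cached.shape[0])
        return cos_cached[:n].detach().clone()


ep = torch.export.export(SliceCache(), (torch.rand(16, 8), torch.tensor(4)))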
auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+            # The test should always be True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
auto/patch_transformers: MistralRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+            # The test should always be True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
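The patched dynamic_frequency_update and longrope_frequency_update above replace a data-dependent Python if with torch.cond, so both the recomputed inv_freq and the original one remain in the exported graph. Below is a minimal, self-contained sketch of that selection pattern; the tensors are made up and stand in for the model buffers.

import torch


def select_inv_freq(seq_len, original_max_seq_len, long_inv_freq, original_inv_freq):
    # torch.cond records both branches in the exported graph instead of
    # tracing only the branch taken on the example inputs.
    cond = (seq_len >= original_max_seq_len).item()
    return torch.cond(
        cond,
        (lambda x, y: x.clone()),
        (lambda x, y: y.clone()),
        [long_inv_freq, original_inv_freq],
    )


inv_freq = select_inv_freq(
    torch.tensor(4096), torch.tensor(2048), torch.rand(8), torch.rand(8)
)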
auto/patch_transformers: MixtralRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+ # The test should always be True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
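The rewritten forward also replaces maybe_autocast with torch.autocast(..., enabled=False), forcing the frequency matmul into float32 before casting cos and sin back to the input dtype. A small runnable sketch of that computation, with made-up shapes for inv_freq and position_ids:

import torch

x = torch.rand(1, 3, 8, dtype=torch.float16)  # hypothetical activations
inv_freq = torch.rand(4)                      # hypothetical frequencies
position_ids = torch.arange(3)[None, :]

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
with torch.autocast(device_type="cpu", enabled=False):  # force float32
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)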
auto/patch_transformers: Phi3RotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+ # The test should always be True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
auto/patch_transformers: Phi4MultimodalRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+ # The test should always be True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
auto/patch_transformers: PhiRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+ # The test should always be True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
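The layer_type argument in the rewritten forward is a compatibility shim: transformers>=5.0 stores one inv_freq and attention_scaling per layer type under prefixed attribute names, while older versions keep a single pair on the module. A minimal sketch of that prefixed lookup; the Rope class and the "sliding_attention" layer type are made up for illustration.

import torch


class Rope(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # one plain buffer (transformers<5.0) and one prefixed buffer per layer type
        self.register_buffer("inv_freq", torch.rand(4), persistent=False)
        self.register_buffer("sliding_attention_inv_freq", torch.rand(4), persistent=False)

    def pick(self, layer_type=None):
        if layer_type is None:
            return self.inv_freq
        return getattr(self, f"{layer_type}_inv_freq")


rope = Rope()
assert rope.pick().shape == rope.pick("sliding_attention").shape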
auto/patch_transformers: Qwen2_5_VLForConditionalGeneration.prepare_inputs_for_generation -> patched_Qwen2_5_VLForConditionalGeneration.prepare_inputs_for_generation¶
1--- original
2+++ rewritten
3@@ -14,9 +14,12 @@
4 second_per_grid_ts=None,
5 **kwargs,
6 ):
7- # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
8+ # Overwritten -- in specific circumstances we don't want to
9+ # forward image inputs to the model
10+ from transformers.generation import GenerationMixin
11
12- model_inputs = super().prepare_inputs_for_generation(
13+ model_inputs = GenerationMixin.prepare_inputs_for_generation(
14+ self,
15 input_ids,
16 past_key_values=past_key_values,
17 attention_mask=attention_mask,
18@@ -36,8 +39,8 @@
19 if position_ids is None:
20 # Calculate RoPE index once per generation in the pre-fill stage only.
21 # When compiling, we can't check tensor values thus we check only input length
22- # It is safe to assume that `length!=1` means we're in pre-fill because compiled
23- # models currently cannot do assisted decoding
24+ # It is safe to assume that `length!=1` means we're in pre-fill
25+ # because compiled models currently cannot do assisted decoding
26 if cache_position[0] == 0 or self.model.rope_deltas is None:
27 vision_positions, rope_deltas = self.model.get_rope_index(
28 model_inputs.get("input_ids", None),
29@@ -48,7 +51,7 @@
30 )
31 self.model.rope_deltas = rope_deltas
32 # then use the prev pre-calculated rope-deltas to get the correct position ids
33- elif "position_ids" in model_inputs:
34+ elif "position_ids" in model_inputs and model_inputs["position_ids"] is not None:
35 batch_size, seq_length = model_inputs["position_ids"].shape
36 device = model_inputs["position_ids"].device
37 position_ids = torch.arange(seq_length, device=device)
38@@ -58,7 +61,14 @@
39 vision_positions = position_ids + delta.expand_as(position_ids)
40
41 # Concatenate "text + vision" positions into [4, bs, seq-len]
42- text_positions = model_inputs["position_ids"][None, ...]
43+ if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
44+ text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
45+ None, None, :
46+ ]
47+ else:
48+ text_positions = model_inputs["position_ids"][None, ...]
49+ # text_positions = model_inputs["position_ids"][None, ...]
50+ assert vision_positions is not None, "vision_positions are missing"
51 model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
52
53 if cache_position[0] != 0:
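The first change above replaces super().prepare_inputs_for_generation(...) with an explicit GenerationMixin.prepare_inputs_for_generation(self, ...). One likely reason is that the patched method is defined outside the class and grafted on at patch time, where zero-argument super() is not available. A minimal sketch with hypothetical classes:

class Base:
    def prepare(self, x):
        return {"x": x}


class Child(Base):
    pass


def patched_prepare(self, x):
    # defined outside Child, so zero-argument super() is unavailable here;
    # the base class is called explicitly instead
    out = Base.prepare(self, x)
    out["patched"] = True
    return out


Child.prepare = patched_prepare
assert Child().prepare(1) == {"x": 1, "patched": True}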
auto/patch_transformers: Qwen2_5_VLModel.get_placeholder_mask -> patched_Qwen2_5_VLModel.get_placeholder_mask¶
1--- original
2+++ rewritten
3@@ -5,20 +5,20 @@
4 image_features: Optional[torch.FloatTensor] = None,
5 video_features: Optional[torch.FloatTensor] = None,
6 ):
7- """
8- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
9- equal to the length of multimodal features. If the lengths are different, an error is raised.
10- """
11 if input_ids is None:
12 special_image_mask = inputs_embeds == self.get_input_embeddings()(
13 torch.tensor(
14- self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
15+ self.config.image_token_id,
16+ dtype=torch.long,
17+ device=inputs_embeds.device,
18 )
19 )
20 special_image_mask = special_image_mask.all(-1)
21 special_video_mask = inputs_embeds == self.get_input_embeddings()(
22 torch.tensor(
23- self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device
24+ self.config.video_token_id,
25+ dtype=torch.long,
26+ device=inputs_embeds.device,
27 )
28 )
29 special_video_mask = special_video_mask.all(-1)
30@@ -26,28 +26,34 @@
31 special_image_mask = input_ids == self.config.image_token_id
32 special_video_mask = input_ids == self.config.video_token_id
33
34- n_image_tokens = special_image_mask.sum()
35 special_image_mask = (
36 special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
37 )
38- if (
39- image_features is not None
40- and inputs_embeds[special_image_mask].numel() != image_features.numel()
41- ):
42- raise ValueError(
43- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
44- )
45
46- n_video_tokens = special_video_mask.sum()
47+ # PATCHED: we should use torch._check
48+ # but this fails for compilation. It cannot be verified with FakeTensors
49+ # torch._check(
50+ # image_features is None
51+ # or inputs_embeds[special_image_mask].numel() == image_features.numel(),
52+ # lambda: (
53+ # f"Image features and image tokens do not match: tokens: "
54+ # f"{special_image_mask.sum()}, features {image_features.shape[0]}"
55+ # ),
56+ # )
57+
58 special_video_mask = (
59 special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
60 )
61- if (
62- video_features is not None
63- and inputs_embeds[special_video_mask].numel() != video_features.numel()
64- ):
65- raise ValueError(
66- f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
67- )
68+
69+ # PATCHED: we should use torch._check
70+ # but this fails for compilation. It cannot be verified with FakeTensors
71+ # torch._check(
72+ # video_features is None
73+ # or inputs_embeds[special_video_mask].numel() == video_features.numel(),
74+ # lambda: (
75+ # f"Videos features and video tokens do not match: tokens: "
76+ # f"{special_video_mask.sum()}, features {video_features.shape[0]}"
77+ # ),
78+ # )
79
80 return special_image_mask, special_video_mask
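The two ValueError branches removed here compare the number of selected embedding elements with the number of multimodal features, a data-dependent condition that cannot be evaluated on FakeTensors; the commented-out torch._check calls show the alternative that still fails during export, so the patch keeps only the mask construction. A small sketch of the mask logic that remains, with an assumed token id and hypothetical shapes:

import torch

input_ids = torch.tensor([[1, 32000, 32000, 2]])   # 32000: assumed image_token_id
inputs_embeds = torch.randn(1, 4, 8)

special_image_mask = input_ids == 32000
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
print(special_image_mask.shape)   # torch.Size([1, 4, 8])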
auto/patch_transformers: Qwen2_5_VLVisionAttention.forward -> patched_Qwen2_5_VLVisionAttention.forward¶
1--- original
2+++ rewritten
3@@ -7,24 +7,82 @@
4 **kwargs,
5 ) -> torch.Tensor:
6 seq_length = hidden_states.shape[0]
7- query_states, key_states, value_states = (
8- self.qkv(hidden_states)
9- .reshape(seq_length, 3, self.num_heads, -1)
10- .permute(1, 0, 2, 3)
11- .unbind(0)
12+ # PATCHED: avoid the use of unbind
13+ qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3)
14+
15+ query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
16+ cos, sin = position_embeddings
17+
18+ # This part should be moved into the loop
19+ # iteration to enable fusion inside the loop.
20+ query_states, key_states = (
21+ transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
22+ query_states, key_states, cos, sin
23+ )
24 )
25- cos, sin = position_embeddings
26- query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
27
28 query_states = query_states.transpose(0, 1).unsqueeze(0)
29 key_states = key_states.transpose(0, 1).unsqueeze(0)
30 value_states = value_states.transpose(0, 1).unsqueeze(0)
31
32- attention_interface: Callable = eager_attention_forward
33+ attention_interface: Callable = (
34+ transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
35+ )
36 if self.config._attn_implementation != "eager":
37- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
38+ # PATCHED
39+ # attention_interface = ALL_ATTENTION_FUNCTIONS[
40+ # self.config._attn_implementation]
41+ attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
42+ self.config._attn_implementation
43+ ]
44
45- if self.config._attn_implementation == "flash_attention_2":
46+ is_sdpa_or_eager = (
47+ attention_interface is transformers.integrations.sdpa_attention.sdpa_attention_forward
48+ or attention_interface is patched_sdpa_attention_forward
49+ or attention_interface
50+ is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
51+ )
52+ if is_sdpa_or_eager:
53+ attn_output = qwen_sdpa_attention_versatile(
54+ query_states,
55+ key_states,
56+ value_states,
57+ cu_seqlens,
58+ self.scaling,
59+ self.num_heads,
60+ )
61+ elif _is_torchdynamo_exporting():
62+ if self.config._attn_implementation == "flash_attention_2":
63+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
64+ attn_output = torch.onnx.ops.symbolic(
65+ "custom::qwen25_flash_attention",
66+ (
67+ query_states,
68+ key_states,
69+ value_states,
70+ cu_seqlens,
71+ cu_seqlens,
72+ max_seqlen,
73+ max_seqlen,
74+ torch.tensor(self.scaling, dtype=torch.float32),
75+ ),
76+ dtype=query_states.dtype,
77+ shape=(
78+ query_states.shape[0], # batch_size
79+ query_states.shape[2], # sequence_length (total patches)
80+ query_states.shape[1], # num_heads
81+ query_states.shape[3], # head_size
82+ ),
83+ version=1,
84+ )
85+ else:
86+ raise NotImplementedError(
87+ f"No corresponding export strategy for implementation "
88+ f"{self.config._attn_implementation!r}, "
89+ f"(use QWEN25ATTENTION to change it), and attention_interface="
90+ f"{attention_interface!r} (use sdpa)"
91+ )
92+ elif self.config._attn_implementation == "flash_attention_2":
93 # Flash Attention 2: Use cu_seqlens for variable length attention
94 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
95 attn_output, _ = attention_interface(
96@@ -44,6 +102,7 @@
97 )
98 else:
99 # Other implementations: Process each chunk separately
100+ # = qwen_sdpa_attention
101 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
102 splits = [
103 torch.split(tensor, lengths.tolist(), dim=2)
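Two export-related changes are visible here: the stacked QKV projection is split by indexing rather than unbind, and when torch.export is running with flash_attention_2 the attention call is emitted as a custom symbolic op instead of being traced through. The indexing trick alone looks like this (a sketch with arbitrary sizes):

import torch

seq_length, num_heads, head_dim = 5, 4, 8
qkv = torch.randn(seq_length, 3 * num_heads * head_dim)   # stand-in for self.qkv(hidden_states)
qkv = qkv.reshape(seq_length, 3, num_heads, head_dim).permute(1, 0, 2, 3)

# Instead of: query_states, key_states, value_states = qkv.unbind(0)
query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]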
auto/patch_transformers: Qwen2_5_VisionTransformerPretrainedModel.get_window_index -> patched_Qwen2_5_VisionTransformerPretrainedModel.get_window_index¶
1--- original
2+++ rewritten
3@@ -1,10 +1,15 @@
4 def get_window_index(self, grid_thw):
5- window_index: list = []
6- cu_window_seqlens: list = [0]
7+ window_index: list = [] # type: ignore[annotation-unchecked]
8+ # PATCHED
9+ cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)] # type: ignore[annotation-unchecked]
10 window_index_id = 0
11 vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
12
13- for grid_t, grid_h, grid_w in grid_thw:
14+ for _thw in grid_thw:
15+ # PATCHED: avoid unbind
16+ grid_t = _thw[0]
17+ grid_h = _thw[1]
18+ grid_w = _thw[2]
19 llm_grid_h, llm_grid_w = (
20 grid_h // self.spatial_merge_size,
21 grid_w // self.spatial_merge_size,
22@@ -34,9 +39,11 @@
23 index_padded = index_padded.reshape(-1)
24 index_new = index_padded[index_padded != -100]
25 window_index.append(index_new + window_index_id)
26- cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
27- cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
28+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
29+ # PATCHED
30+ # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
31+ cu_window_seqlens.append(cu_seqlens_tmp)
32 window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
33 window_index = torch.cat(window_index, dim=0)
34
35- return window_index, cu_window_seqlens
36+ return window_index, torch.cat(cu_window_seqlens, dim=0)
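The rewritten get_window_index keeps cu_window_seqlens as a list of tensors and concatenates them at the end, instead of extending a Python list via .tolist(), which would pull data-dependent integers out of the graph. A sketch of that accumulation pattern with hypothetical per-image window lengths:

import torch

cu_window_seqlens = [torch.tensor([0], dtype=torch.int64)]
for seqlens in (torch.tensor([2, 3]), torch.tensor([1, 4])):   # hypothetical window sizes
    cu_window_seqlens.append(seqlens.cumsum(0) + cu_window_seqlens[-1][-1:])
print(torch.cat(cu_window_seqlens, dim=0))   # tensor([ 0,  2,  5,  6, 10])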
auto/patch_transformers: Qwen2_5_VisionTransformerPretrainedModel.forward -> patched_Qwen2_5_VisionTransformerPretrainedModel.forward¶
1--- original
2+++ rewritten
3@@ -12,11 +12,13 @@
4 hidden_states = self.patch_embed(hidden_states)
5 rotary_pos_emb = self.rot_pos_emb(grid_thw)
6 window_index, cu_window_seqlens = self.get_window_index(grid_thw)
7- cu_window_seqlens = torch.tensor(
8- cu_window_seqlens,
9- device=hidden_states.device,
10- dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
11- )
12+ # PATCHED
13+ # cu_window_seqlens = torch.tensor(
14+ # cu_window_seqlens,
15+ # device=hidden_states.device,
16+ # dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
17+ # )
18+ cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
19 cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
20
21 seq_len, _ = hidden_states.size()
22@@ -37,8 +39,10 @@
23 dim=0,
24 # Select dtype based on the following factors:
25 # - FA2 requires that cu_seqlens_q must have dtype int32
26- # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
27- # See https://github.com/huggingface/transformers/pull/34852 for more information
28+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype
29+ # as grid_thw
30+ # See https://github.com/huggingface/transformers/pull/34852
31+ # for more information
32 dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
33 )
34 cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
35@@ -55,9 +59,10 @@
36 position_embeddings=position_embeddings,
37 **kwargs,
38 )
39+ if STOPAT is not None and layer_num > STOPAT:
40+ break
41
42 hidden_states = self.merger(hidden_states)
43 reverse_indices = torch.argsort(window_index)
44 hidden_states = hidden_states[reverse_indices, :]
45-
46 return hidden_states
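Because get_window_index now returns a tensor, the rewritten forward only has to cast it to the right device and dtype; the original torch.tensor(cu_window_seqlens, ...) call would have needed the data-dependent Python list. A sketch with made-up values (the STOPAT early break is a debugging hook of the patch and is left out here):

import torch

cu_window_seqlens = torch.tensor([0, 4, 4, 9], dtype=torch.int64)   # hypothetical values
grid_thw = torch.tensor([[1, 2, 2]], dtype=torch.int64)
cu_window_seqlens = cu_window_seqlens.to("cpu").to(grid_thw.dtype)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)     # tensor([0, 4, 9])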
auto/patch_transformers: Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb -> patched_Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb¶
1--- original
2+++ rewritten
3@@ -1,6 +1,10 @@
4 def rot_pos_emb(self, grid_thw):
5 pos_ids = []
6- for t, h, w in grid_thw:
7+ for thw_ in grid_thw:
8+ # PATCHED: avoid unbind
9+ t = thw_[0]
10+ h = thw_[1]
11+ w = thw_[2]
12 hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
13 hpos_ids = hpos_ids.reshape(
14 h // self.spatial_merge_size,
auto/patch_transformers: Qwen3MoeSparseMoeBlock.forward -> patched_Qwen3MoeSparseMoeBlock.forward¶
1--- original
2+++ rewritten
3@@ -1,6 +1,67 @@
4-def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
5+def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
6+ """ """
7 batch_size, sequence_length, hidden_dim = hidden_states.shape
8- hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
9- _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
10- final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
11- return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
12+ hidden_states = hidden_states.view(-1, hidden_dim)
13+ # router_logits: (batch * sequence_length, n_experts)
14+ router_logits = self.gate(hidden_states)
15+
16+ routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
17+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
18+ if self.norm_topk_prob: # only diff with mixtral sparse moe block!
19+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
20+ # we cast back to the input dtype
21+ routing_weights = routing_weights.to(hidden_states.dtype)
22+
23+ final_hidden_states = torch.zeros(
24+ (batch_size * sequence_length, hidden_dim),
25+ dtype=hidden_states.dtype,
26+ device=hidden_states.device,
27+ )
28+
29+ # One hot encode the selected experts to create an expert mask
30+ # this will be used to easily index which expert is going to be sollicitated
31+ expert_mask = torch.nn.functional.one_hot(
32+ selected_experts, num_classes=self.num_experts
33+ ).permute(2, 1, 0)
34+
35+ # Loop over all available experts in the model
36+ # and perform the computation on each expert
37+ expert_sum = expert_mask.sum(dim=(-1, -2))
38+ # expert_hit = torch.greater(expert_sum, 0).nonzero()
39+ # for expert_idx in expert_hit:
40+ for expert_idx in range(self.num_experts):
41+ # initial code has a squeeze but it is not possible to do that.
42+ # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
43+ expert_mask_idx = expert_mask[expert_idx]
44+ final_hidden_states = torch.cond(
45+ (expert_sum[expert_idx] > 0).item(),
46+ lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop( # noqa: E501
47+ final_hidden_states,
48+ expert_mask,
49+ hidden_states,
50+ routing_weights,
51+ expert_idx=_i,
52+ ),
53+ lambda final_hidden_states, *args: final_hidden_states.clone(),
54+ [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
55+ )
56+
57+ # if expert_sum[expert_idx] > 0:
58+ # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
59+
60+ # Index the correct hidden states and compute the expert hidden state for
61+ # the current expert. We need to make sure to multiply the output hidden
62+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
63+ # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
64+ # current_hidden_states = (
65+ # expert_layer(current_state) * routing_weights[top_x, idx, None]
66+ # )
67+
68+ # However `index_add_` only support torch tensors for indexing so we'll use
69+ # the `top_x` tensor here.
70+ # final_hidden_states.index_add_(
71+ # 0, top_x, current_hidden_states.to(hidden_states.dtype)
72+ # )
73+
74+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
75+ return final_hidden_states, router_logits
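The rewritten MoE forward replaces the data-dependent expert_hit loop with one torch.cond per expert: both branches receive the same operands and return a tensor, so the "was this expert selected?" decision stays inside the graph. A minimal sketch of that pattern (not the model code itself):

import torch

def maybe_accumulate(acc: torch.Tensor, update: torch.Tensor, hit: torch.Tensor) -> torch.Tensor:
    # ``hit`` plays the role of ``expert_sum[expert_idx] > 0`` in the patch.
    return torch.cond(
        hit,
        lambda acc, update: acc + update,     # expert contributes
        lambda acc, update: acc.clone(),      # expert is skipped
        [acc, update],
    )

acc = torch.zeros(4)
print(maybe_accumulate(acc, torch.ones(4), torch.tensor(True)))   # tensor([1., 1., 1., 1.])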
auto/patch_transformers: ‘Qwen3MoeSparseMoeBlock_forward_expert_loop’ -> patched_Qwen3MoeSparseMoeBlock._forward_expert_loop¶
1def _forward_expert_loop(
2 self,
3 final_hidden_states,
4 expert_mask_idx,
5 hidden_states,
6 routing_weights,
7 expert_idx: int,
8):
9 # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
10 idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
11 hidden_dim = hidden_states.shape[-1]
12 current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
13 expert_current_state = self.experts[expert_idx](current_state)
14 current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
15 return final_hidden_states.index_add(0, top_x, current_hidden_states.to(hidden_states.dtype))
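Note that the helper uses the out-of-place index_add rather than index_add_, so the branch handed to torch.cond stays purely functional. A tiny illustration:

import torch

final = torch.zeros(5, 3)
top_x = torch.tensor([1, 3])
update = torch.ones(2, 3)
result = final.index_add(0, top_x, update)   # ``final`` itself is left unchanged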
auto/patch_transformers: SamMaskDecoder.forward -> patched_SamMaskDecoder.forward¶
1--- original
2+++ rewritten
3@@ -5,6 +5,7 @@
4 sparse_prompt_embeddings: torch.Tensor,
5 dense_prompt_embeddings: torch.Tensor,
6 multimask_output: bool,
7+ output_attentions: Optional[bool] = None,
8 attention_similarity: Optional[torch.Tensor] = None,
9 target_embedding: Optional[torch.Tensor] = None,
10 ) -> tuple[torch.Tensor, torch.Tensor]:
11@@ -22,19 +23,31 @@
12 the embeddings of the mask inputs
13 multimask_output (bool):
14 Whether to return multiple masks or a single mask.
15+ output_attentions (bool, *optional*):
16+ Whether or not to return the attentions tensors of all attention layers.
17 """
18 batch_size, num_channels, height, width = image_embeddings.shape
19- point_batch_size = (
20- sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
21- )
22+ point_batch_size = sparse_prompt_embeddings.shape[1]
23 # Concatenate output tokens
24 output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
25 output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
26
27- if sparse_prompt_embeddings is not None:
28- tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
29- else:
30- tokens = output_tokens
31+ # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
32+ # torch.any is needed to avoid data-dependent control flow
33+ # with sparse_prompt_embeddings.sum().item() != 0
34+ def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
35+ return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
36+
37+ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
38+ return output_tokens.clone()
39+
40+ tokens = torch.cond(
41+ torch.any(sparse_prompt_embeddings != 0),
42+ sparse_prompt_embeddings_is_not_empty,
43+ sparse_prompt_embeddings_is_empty,
44+ [output_tokens, sparse_prompt_embeddings],
45+ )
46+
47 point_embeddings = tokens.to(self.iou_token.weight.dtype)
48
49 # Expand per-image data in batch direction to be per-point
50@@ -45,15 +58,21 @@
51 )
52
53 # Run the transformer, image_positional_embedding are consumed
54- point_embedding, image_embeddings = self.transformer(
55+ torch._check(point_embeddings.shape[0] != 0)
56+ torch._check(point_embeddings.shape[1] != 0)
57+ torch._check(point_embeddings.shape[2] != 0)
58+ torch._check(point_embeddings.shape[3] != 0)
59+ embeddings_attentions = self.transformer(
60 point_embeddings=point_embeddings,
61 image_embeddings=image_embeddings,
62 image_positional_embeddings=image_positional_embeddings,
63 attention_similarity=attention_similarity,
64 target_embedding=target_embedding,
65+ output_attentions=output_attentions,
66 )
67- iou_token_out = point_embedding[:, :, 0, :]
68- mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
69+ point_embedding, image_embeddings = embeddings_attentions[:2]
70+ iou_token_out = torch.select(point_embedding, dim=2, index=0)
71+ mask_tokens_out = torch.narrow(point_embedding, dim=2, start=1, length=self.num_mask_tokens)
72
73 # Upscale mask embeddings and predict masks using the mask tokens
74 image_embeddings = image_embeddings.transpose(2, 3).reshape(
75@@ -88,4 +107,15 @@
76 mask_slice = slice(0, 1)
77 masks = masks[:, :, mask_slice, :, :]
78 iou_pred = iou_pred[:, :, mask_slice]
79- return masks, iou_pred
80+
81+ outputs = (masks, iou_pred)
82+
83+ if len(embeddings_attentions) == 2:
84+ # transformers==4.54
85+ return outputs
86+
87+ if output_attentions and len(embeddings_attentions) > 2:
88+ outputs = outputs + (embeddings_attentions[2],) # noqa: RUF005
89+ else:
90+ outputs = outputs + (None,) # noqa: RUF005
91+ return outputs
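The key rewrite here turns the "if sparse_prompt_embeddings is not None" branch into a torch.cond whose predicate is torch.any(sparse_prompt_embeddings != 0), so the empty-prompt case becomes a tensor predicate rather than Python control flow. A reduced sketch of the same pattern, run eagerly:

import torch

def build_tokens(output_tokens: torch.Tensor, sparse: torch.Tensor) -> torch.Tensor:
    return torch.cond(
        torch.any(sparse != 0),
        lambda out, sp: torch.cat((out, sp), dim=2),
        lambda out, sp: out.clone(),
        [output_tokens, sparse],
    )

tokens = build_tokens(torch.zeros(2, 1, 5, 8), torch.zeros(2, 1, 2, 8))
print(tokens.shape)   # torch.Size([2, 1, 5, 8]) since the prompt embeddings are all zeros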
auto/patch_transformers: SmolLM3RotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,18 +1,26 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24 device_type = (
25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
26 )
27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
30 emb = torch.cat((freqs, freqs), dim=-1)
31- cos = emb.cos() * self.attention_scaling
32- sin = emb.sin() * self.attention_scaling
33+ cos = emb.cos() * attention_scaling
34+ sin = emb.sin() * attention_scaling
35
36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
37
38--- original
39+++ rewritten
40@@ -1,99 +1,193 @@
41-def dynamic_rope_update(rope_forward):
42- """
43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
45+def patched_dynamic_rope_update(rope_forward):
46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
47
48- Args:
49- rope_forward (Callable):
50- The forward pass of the RoPE implementation.
51+ ``rope_type`` is determined in the constructor of class
52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
53
54- Returns:
55- The decorated forward pass.
56+ .. code-block:: python
57+
58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
59+ self.rope_type = config.rope_scaling.get(
60+ "rope_type", config.rope_scaling.get("type"))
61+ else:
62+ self.rope_type = "default"
63+
64+ The original code of the patched function:
65+
66+ .. code-block:: python
67+
68+ def dynamic_rope_update(rope_forward):
69+ def longrope_frequency_update(self, position_ids, device):
70+ seq_len = torch.max(position_ids) + 1
71+ if hasattr(self.config, "original_max_position_embeddings"):
72+ original_max_position_embeddings =
73+ self.config.original_max_position_embeddings
74+ else:
75+ original_max_position_embeddings =
76+ self.config.max_position_embeddings
77+ if seq_len > original_max_position_embeddings:
78+ if not hasattr(self, "long_inv_freq"):
79+ self.long_inv_freq, _ = self.rope_init_fn(
80+ self.config, device, seq_len=original_max_position_embeddings + 1
81+ )
82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
83+ else:
84+ self.original_inv_freq = self.original_inv_freq.to(device)
85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
86+
87+ def dynamic_frequency_update(self, position_ids, device):
88+ seq_len = torch.max(position_ids) + 1
89+ if seq_len > self.max_seq_len_cached: # growth
90+ inv_freq, self.attention_scaling = self.rope_init_fn(
91+ self.config, device, seq_len=seq_len)
92+ self.register_buffer("inv_freq", inv_freq, persistent=False)
93+ self.max_seq_len_cached = seq_len
94+
95+ if seq_len < self.original_max_seq_len and
96+ self.max_seq_len_cached > self.original_max_seq_len:
97+ self.original_inv_freq = self.original_inv_freq.to(device)
98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
99+ self.max_seq_len_cached = self.original_max_seq_len
100+
101+ @wraps(rope_forward)
102+ def wrapper(self, x, position_ids):
103+ if "dynamic" in self.rope_type:
104+ dynamic_frequency_update(self, position_ids, device=x.device)
105+ elif self.rope_type == "longrope":
106+ longrope_frequency_update(self, position_ids, device=x.device)
107+ return rope_forward(self, x, position_ids)
108+
109+ return wrapper
110+
111 """
112
113 def longrope_frequency_update(self, position_ids, device, layer_type=None):
114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+ # It is no use to patch the function after the model is created
116+ # as rope_init_fn is an attribute set to one function when the model
117+ # is created and when no patch is applied yet.
118+ # So we select the patched version here.
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120 seq_len = torch.max(position_ids) + 1
121- original_max_position_embeddings = getattr(
122- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123- )
124+ if hasattr(self.config, "original_max_position_embeddings"):
125+ original_max_position_embeddings = self.config.original_max_position_embeddings
126+ else:
127+ original_max_position_embeddings = self.config.max_position_embeddings
128+
129 if layer_type is None:
130- rope_type = self.rope_type
131+ # rope_type = self.rope_type
132 original_inv_freq = self.original_inv_freq
133 prefix = ""
134 else:
135- rope_type = self.rope_type[layer_type]
136+ # rope_type = self.rope_type[layer_type]
137 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138 prefix = f"{layer_type}_"
139
140- if seq_len > original_max_position_embeddings:
141- if not hasattr(self, f"{layer_type}_long_inv_freq"):
142- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143- long_inv_freq, _ = rope_init_fn(
144- self.config,
145- device,
146- seq_len=original_max_position_embeddings + 1,
147- layer_type=layer_type,
148- )
149- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151- else:
152- # This .to() is needed if the model has been moved to a device after being initialized (because
153- # the buffer is automatically moved, but not the original copy)
154- original_inv_freq = original_inv_freq.to(device)
155- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+ # At export time, seq_len is unknown.
158+ long_inv_freq, _ = rope_init_fn(
159+ self.config, device, seq_len=original_max_position_embeddings + 1
160+ )
161+ original_inv_freq = self.original_inv_freq.to(device)
162+
163+ # PATCHED: uses torch.cond instead of a test
164+ cond = (seq_len > original_max_position_embeddings).item()
165+ inv_freq = torch.cond(
166+ cond,
167+ (lambda x, y: x.clone()),
168+ (lambda x, y: y.clone()),
169+ [long_inv_freq, original_inv_freq],
170+ )
171+ setattr(self, f"{prefix}inv_freq", inv_freq)
172+ # if seq_len > original_max_position_embeddings:
173+ # self.inv_freq = self.long_inv_freq
174+ # else:
175+ # self.inv_freq = self.original_inv_freq
176
177 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178- """
179- dynamic RoPE layers should recompute `inv_freq` in the following situations:
180- 1 - growing beyond the cached sequence length (allow scaling)
181- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182- """
183+ # constructor:
184+ # - self.max_seq_len_cached = config.max_position_embeddings
185+ # - self.original_max_seq_len = config.max_position_embeddings
186+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+ # It is no use to patch the function after the model is created
189+ # as rope_init_fn is an attribute set to one function when the model
190+ # is created and when no patch is applied yet.
191+ # So we select the patched version here.
192+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+ # This behaviour is difficult to translate.
195+ # The sequence always grows.
196+ # The test should always True.
197+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+ #
199+ # if seq_len > self.max_seq_len_cached: # growth
200+ # inv_freq, self.attention_scaling = self.rope_init_fn(
201+ # self.config, device, seq_len=seq_len
202+ # )
203+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
204+ # self.max_seq_len_cached = seq_len
205+ #
206+ # So we should not need what follows.
207+ #
208+ # cond = (seq_len > self.max_seq_len_cached).item()
209+ # self.attention_scaling = torch.cond(
210+ # cond,
211+ # (lambda x, y: x.clone()),
212+ # (lambda x, y: y.clone()),
213+ # [attention_scaling, self.attention_scaling],
214+ # )
215+
216 seq_len = torch.max(position_ids) + 1
217+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219 if layer_type is None:
220- rope_type = self.rope_type
221- max_seq_len_cached = self.max_seq_len_cached
222+ # rope_type = self.rope_type
223+ # max_seq_len_cached = self.max_seq_len_cached
224 original_inv_freq = self.original_inv_freq
225 prefix = ""
226 else:
227- rope_type = self.rope_type[layer_type]
228- max_seq_len_cached = getattr(
229- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230- )
231+ # rope_type = self.rope_type[layer_type]
232+ # max_seq_len_cached = getattr(
233+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+ # )
235 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236 prefix = f"{layer_type}_"
237
238- if seq_len > max_seq_len_cached: # growth
239- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240- inv_freq, self.attention_scaling = rope_init_fn(
241- self.config,
242- device,
243- seq_len=seq_len,
244- layer_type=layer_type,
245- )
246- # TODO joao: may break with compilation
247- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+ # Second test to translate.
250+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+ # But in that case the following condition is a way to restore the original cache.
252
253- if (
254- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255- ): # reset
256- # This .to() is needed if the model has been moved to a device after being initialized (because
257- # the buffer is automatically moved, but not the original copy)
258- original_inv_freq = original_inv_freq.to(device)
259- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+ # if (
263+ # seq_len < self.original_max_seq_len
264+ # and self.max_seq_len_cached > self.original_max_seq_len
265+ # ):
266+ # self.original_inv_freq = self.original_inv_freq.to(device)
267+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+ # self.max_seq_len_cached = self.original_max_seq_len
269+
270+ original_inv_freq = self.original_inv_freq.to(device)
271+ cond = (seq_len >= self.original_max_seq_len).item()
272+ # PATCHED: uses torch.cond instead of a test
273+ inv_freq = torch.cond(
274+ cond,
275+ (lambda x, y: x.clone()),
276+ (lambda x, y: y.clone()),
277+ [long_inv_freq, original_inv_freq],
278+ )
279+ setattr(self, f"{prefix}inv_freq", inv_freq)
280
281 @wraps(rope_forward)
282 def wrapper(self, x, position_ids, layer_type=None):
283- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285- if "dynamic" in rope_type:
286- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287- elif rope_type == "longrope":
288- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289- return rope_forward(self, x, position_ids, **kwargs)
290+ if layer_type is None:
291+ if "dynamic" in self.rope_type:
292+ dynamic_frequency_update(self, position_ids, device=x.device)
293+ elif self.rope_type == "longrope":
294+ longrope_frequency_update(self, position_ids, device=x.device)
295+ return rope_forward(self, x, position_ids)
296+
297+ if "dynamic" in self.rope_type:
298+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+ elif self.rope_type == "longrope":
300+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+ return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303 return wrapper
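In both update paths the patch precomputes the candidate frequency tables and then lets torch.cond pick one from the data-dependent sequence length, instead of branching in Python on seq_len. A condensed sketch of that selection (hypothetical values, not the RoPE initialization itself):

import torch

def select_inv_freq(seq_len: torch.Tensor, threshold: int,
                    long_inv_freq: torch.Tensor, original_inv_freq: torch.Tensor) -> torch.Tensor:
    return torch.cond(
        (seq_len > threshold).item(),
        lambda long_f, orig_f: long_f.clone(),
        lambda long_f, orig_f: orig_f.clone(),
        [long_inv_freq, original_inv_freq],
    )

inv_freq = select_inv_freq(torch.tensor(4096), 2048, torch.full((8,), 0.5), torch.ones(8))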
auto/patch_transformers: VisionAttention.forward -> patched_VisionAttention.forward¶
1--- original
2+++ rewritten
3@@ -3,69 +3,55 @@
4 hidden_states: torch.Tensor,
5 cu_seqlens: torch.Tensor,
6 rotary_pos_emb: Optional[torch.Tensor] = None,
7- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
8- **kwargs,
9+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
10 ) -> torch.Tensor:
11 seq_length = hidden_states.shape[0]
12- query_states, key_states, value_states = (
13+ q, k, v = (
14 self.qkv(hidden_states)
15 .reshape(seq_length, 3, self.num_heads, -1)
16 .permute(1, 0, 2, 3)
17 .unbind(0)
18 )
19- cos, sin = position_embeddings
20- query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
21+ if position_embeddings is None:
22+ transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
23+ "The attention layers in this model are transitioning from "
24+ " computing the RoPE embeddings internally "
25+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
26+ "to using externally computed "
27+ "`position_embeddings` (Tuple of tensors, containing cos and sin)."
28+ " In v4.54 `rotary_pos_emb` will be "
29+ "removed and `position_embeddings` will be mandatory."
30+ )
31+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
32+ cos = emb.cos()
33+ sin = emb.sin()
34+ else:
35+ cos, sin = position_embeddings
36+ q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
37+ q, k, cos, sin
38+ )
39
40- query_states = query_states.transpose(0, 1).unsqueeze(0)
41- key_states = key_states.transpose(0, 1).unsqueeze(0)
42- value_states = value_states.transpose(0, 1).unsqueeze(0)
43+ attention_mask = torch.full(
44+ [1, seq_length, seq_length],
45+ torch.finfo(q.dtype).min,
46+ device=q.device,
47+ dtype=q.dtype,
48+ )
49+ # for i in range(1, len(cu_seqlens)):
50+ # attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
51+ # cu_seqlens[i - 1] : cu_seqlens[i]] = 0
52+ attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
53
54- attention_interface: Callable = eager_attention_forward
55- if self.config._attn_implementation != "eager":
56- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
57-
58- if self.config._attn_implementation == "flash_attention_2":
59- # Flash Attention 2: Use cu_seqlens for variable length attention
60- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
61- attn_output, _ = attention_interface(
62- self,
63- query_states,
64- key_states,
65- value_states,
66- attention_mask=None,
67- scaling=self.scaling,
68- dropout=0.0 if not self.training else self.attention_dropout,
69- cu_seq_lens_q=cu_seqlens,
70- cu_seq_lens_k=cu_seqlens,
71- max_length_q=max_seqlen,
72- max_length_k=max_seqlen,
73- is_causal=False,
74- **kwargs,
75- )
76- else:
77- # Other implementations: Process each chunk separately
78- lengths = cu_seqlens[1:] - cu_seqlens[:-1]
79- splits = [
80- torch.split(tensor, lengths.tolist(), dim=2)
81- for tensor in (query_states, key_states, value_states)
82- ]
83-
84- attn_outputs = [
85- attention_interface(
86- self,
87- q,
88- k,
89- v,
90- attention_mask=None,
91- scaling=self.scaling,
92- dropout=0.0 if not self.training else self.attention_dropout,
93- is_causal=False,
94- **kwargs,
95- )[0]
96- for q, k, v in zip(*splits)
97- ]
98- attn_output = torch.cat(attn_outputs, dim=1)
99-
100- attn_output = attn_output.reshape(seq_length, -1).contiguous()
101+ q = q.transpose(0, 1)
102+ k = k.transpose(0, 1)
103+ v = v.transpose(0, 1)
104+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / (self.head_dim**0.5)
105+ attn_weights = attn_weights + attention_mask
106+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
107+ q.dtype
108+ )
109+ attn_output = torch.matmul(attn_weights, v)
110+ attn_output = attn_output.transpose(0, 1)
111+ attn_output = attn_output.reshape(seq_length, -1)
112 attn_output = self.proj(attn_output)
113 return attn_output
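Instead of splitting queries, keys and values per cu_seqlens chunk, the rewritten forward applies a single block-diagonal additive mask built by rewrite_loop_for_square_mask. One possible vectorized formulation of such a mask, shown as a sketch and not as the library helper's actual implementation:

import torch

def block_diagonal_bias(cu_seqlens: torch.Tensor, seq_length: int,
                        dtype: torch.dtype = torch.float32) -> torch.Tensor:
    idx = torch.arange(seq_length, device=cu_seqlens.device)
    seg = torch.searchsorted(cu_seqlens[1:], idx, right=True)   # segment id of each position
    same = seg[:, None] == seg[None, :]
    bias = torch.full((1, seq_length, seq_length), torch.finfo(dtype).min,
                      dtype=dtype, device=cu_seqlens.device)
    return bias.masked_fill(same, 0.0)

print(block_diagonal_bias(torch.tensor([0, 2, 5]), 5)[0])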
auto/patch_transformers: eager_attention_forward -> patched_model_bart_eager_attention_forward¶
1--- original
2+++ rewritten
3@@ -1,27 +1,23 @@
4-def eager_attention_forward(
5- module: nn.Module,
6+def patched_model_bart_eager_attention_forward(
7+ module: torch.nn.Module,
8 query: torch.Tensor,
9 key: torch.Tensor,
10 value: torch.Tensor,
11 attention_mask: Optional[torch.Tensor],
12 scaling: Optional[float] = None,
13 dropout: float = 0.0,
14- **kwargs: Unpack[TransformersKwargs],
15+ head_mask: Optional[torch.Tensor] = None,
16+ **kwargs,
17 ):
18- if scaling is None:
19- scaling = query.size(-1) ** -0.5
20-
21- # Take the dot product between "query" and "key" to get the raw attention scores.
22- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24- if attention_mask is not None:
25- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26- attn_weights = attn_weights + attention_mask
27-
28- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31- attn_output = torch.matmul(attn_weights, value)
32- attn_output = attn_output.transpose(1, 2).contiguous()
33-
34- return attn_output, attn_weights
35+ """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
36+ return common_eager_attention_forward(
37+ module,
38+ query,
39+ key,
40+ value,
41+ attention_mask=attention_mask,
42+ scaling=scaling,
43+ dropout=dropout,
44+ head_mask=head_mask,
45+ **kwargs,
46+ )
auto/patch_transformers: eager_attention_forward -> patched_modeling_marian_eager_attention_forward¶
1--- original
2+++ rewritten
3@@ -1,27 +1,23 @@
4-def eager_attention_forward(
5- module: nn.Module,
6+def patched_modeling_marian_eager_attention_forward(
7+ module: torch.nn.Module,
8 query: torch.Tensor,
9 key: torch.Tensor,
10 value: torch.Tensor,
11 attention_mask: Optional[torch.Tensor],
12 scaling: Optional[float] = None,
13 dropout: float = 0.0,
14- **kwargs: Unpack[TransformersKwargs],
15+ head_mask: Optional[torch.Tensor] = None,
16+ **kwargs,
17 ):
18- if scaling is None:
19- scaling = query.size(-1) ** -0.5
20-
21- # Take the dot product between "query" and "key" to get the raw attention scores.
22- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24- if attention_mask is not None:
25- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26- attn_weights = attn_weights + attention_mask
27-
28- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31- attn_output = torch.matmul(attn_weights, value)
32- attn_output = attn_output.transpose(1, 2).contiguous()
33-
34- return attn_output, attn_weights
35+ """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
36+ return common_eager_attention_forward(
37+ module,
38+ query,
39+ key,
40+ value,
41+ attention_mask=attention_mask,
42+ scaling=scaling,
43+ dropout=dropout,
44+ head_mask=head_mask,
45+ **kwargs,
46+ )
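Both the Bart and Marian patches above replace the module-local eager_attention_forward with a thin wrapper around a shared common_eager_attention_forward that also accepts a head_mask. A hedged sketch of what such a shared helper does; the head_mask multiplication follows the usual transformers convention and is an assumption here, not a copy of the library code:

import torch
from typing import Optional

def eager_attention(module: torch.nn.Module, query, key, value,
                    attention_mask: Optional[torch.Tensor] = None,
                    scaling: Optional[float] = None, dropout: float = 0.0,
                    head_mask: Optional[torch.Tensor] = None):
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        # Assumed convention: one scaling factor per attention head.
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights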