Patches Diff
Patches are not always needed to export an LLM.
Most of the time, only serialization functions are needed to export
an LLM with a cache (DynamicCache, …).
The function register_additional_serialization_functions
is enough in many cases.
import torch
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions
with register_additional_serialization_functions(patch_transformers=True):
    ep = torch.export.export(...)
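Concretely, these serialization functions register the cache classes with torch's pytree machinery so that the exporter can flatten and unflatten them. Below is a minimal sketch of what that enables, not taken from this page; it assumes an empty DynamicCache can be created and filled with update() in your transformers version, and the shapes are arbitrary.

import torch
from torch.utils._pytree import tree_flatten
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

# build a small cache with one layer (arbitrary shapes)
cache = DynamicCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)

with register_additional_serialization_functions(patch_transformers=True):
    # once registered, the cache behaves like any other pytree container,
    # which is what torch.export needs to accept it as an input or an output
    flat, spec = tree_flatten(cache)
    print(f"{len(flat)} tensors, spec={spec}")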
The function torch_export_patches
helps fix issues affecting many models.
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(...)
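The context manager is transparent for modules that need no patch, so it can wrap any export call. Below is a minimal sketch, not taken from this page, with a toy model standing in for a real transformers model, showing the patches combined with dynamic shapes.

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TinyModel(torch.nn.Module):
    "Toy stand-in for a transformers model."

    def forward(self, input_ids):
        return input_ids.unsqueeze(-1).expand(-1, -1, 4) * 2


inputs = (torch.arange(6).reshape(2, 3),)
dynamic_shapes = ({0: torch.export.Dim("batch"), 1: torch.export.Dim("seq")},)

with torch_export_patches(patch_transformers=True, patch_torch=True):
    ep = torch.export.export(TinyModel(), inputs, dynamic_shapes=dynamic_shapes)
print(ep)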
The class PatchDetails
shows how to retrieve the list of patches involved for a specific model.
Those patches belong to the following list, which depends on the transformers and
pytorch versions.
<<<
import torch
import transformers
print(torch.__version__, transformers.__version__)
>>>
2.10.0.dev20251022+cu130 5.0.0.dev0
These two versions lead to the following list of patches.
<<<
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_export_patches import torch_export_patches
details = PatchDetails()
with torch_export_patches(
    patch_transformers=True,
    patch_torch=True,
    patch_diffusers=True,
    patch_details=details,
):
    pass
for patch in details.patched:
    if patch.function_to_patch == patch.patch:
        continue
    rst = patch.format_diff(format="rst")
    print()
    print()
    print(rst)
    print()
    print()
>>>
sympy: ‘sympy.core.numbers.IntegerConstant.name’ -> _patch_sympy.<locals>.<lambda>
1sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
torch: infer_size -> patched_infer_size
1--- original
2+++ rewritten
3@@ -1,4 +1,5 @@
4-def infer_size(a, b):
5+def patched_infer_size(a, b):
6+ """Patches ``torch._subclasses.fake_impls.infer_size``."""
7 from torch.fx.experimental.symbolic_shapes import guard_or_false
8
9 dimsA = len(a)
10@@ -23,11 +24,21 @@
11 # expression of an or statement as-is, without bool()'ing it; if this
12 # were not the case, we'd need to write this using torch.sym_or() or
13 # something like that).
14- torch._check(
15- guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
16- lambda: f"The size of tensor a ({sizeA}) "
17- f"must match the size of tensor b ({sizeB}) "
18- f"at non-singleton dimension {i})",
19- )
20- expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
21+ try:
22+ b1 = guard_or_false(sizeA == 1)
23+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
24+ b1 = False
25+ try:
26+ b2 = guard_or_false(sizeB == 1)
27+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
28+ b2 = False
29+ try:
30+ b3 = guard_or_false(sizeA == sizeB)
31+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
32+ b3 = False
33+ if b1 or b2 or b3:
34+ expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
35+ else:
36+ # PATCHED: generic case, the dimension is known, no need to assert
37+ expandedSizes[i] = torch.sym_max(sizeA, sizeB)
38 return tuple(expandedSizes)
torch: _broadcast_shapes -> patched__broadcast_shapes
1--- original
2+++ rewritten
3@@ -1,5 +1,11 @@
4-def _broadcast_shapes(*_shapes):
5- from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int
6+def patched__broadcast_shapes(*_shapes):
7+ """Patches ``torch._refs._broadcast_shapes``."""
8+ from functools import reduce
9+ from torch._prims_common import IntLike
10+ from torch.fx.experimental.symbolic_shapes import (
11+ guard_or_false,
12+ is_nested_int,
13+ )
14
15 shapes = tuple(
16 (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
17@@ -12,17 +18,15 @@
18 for shape in shapes:
19 if not isinstance(shape, Sequence):
20 raise RuntimeError(
21- "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
22+ "Input shapes should be of type ints, a tuple of ints, "
23+ "or a list of ints, got ",
24 shape,
25 )
26
27 # Computes common shape
28- common_shape: list[Union[int, torch.SymInt]] = [
29- 1,
30- ] * reduce(max, (len(shape) for shape in shapes))
31- for arg_idx, shape in enumerate(shapes):
32+ common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
33+ for _arg_idx, shape in enumerate(shapes):
34 for idx in range(-1, -1 - len(shape), -1):
35- # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
36 if is_nested_int(shape[idx]):
37 # Broadcasting is allowed for (j0, 1) or (j0, j0);
38 # not (j0, j1), (j0, 5), etc.
39@@ -33,22 +37,15 @@
40 else:
41 if guard_or_false(shape[idx] == common_shape[idx]):
42 continue
43-
44- if guard_or_false(common_shape[idx] == 1):
45+ # PATCHED: two cases, if == for sure, no broadcast,
46+ # otherwise maybe broadcast with max(dimensions)
47+ if guard_or_false(common_shape[idx] != 1):
48+ pass
49+ elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
50 if shape[idx] < 0:
51 raise ValueError("Attempting to broadcast a dimension with negative length!")
52 common_shape[idx] = shape[idx]
53-
54- if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
55- # broadcast case .
56- continue
57 else:
58- # If broadcasting is undecided we pick non-broadcast path and add runtime assertion.
59- torch._check(
60- common_shape[idx] == shape[idx],
61- lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
62- f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
63- f"should be broadcastable to {common_shape}",
64- )
65+ common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
66
67 return common_shape
torch: _constrain_user_specified_dimhint_range -> patched__constrain_user_specified_dimhint_range
1--- original
2+++ rewritten
3@@ -1,28 +1,31 @@
4-def _constrain_user_specified_dimhint_range(
5+def patched__constrain_user_specified_dimhint_range(
6 symint: torch.SymInt,
7 hint: int,
8- dim: _DimHint,
9+ dim: "_DimHint", # noqa: F821
10 range_constraints,
11 shape_env,
12- keypath: KeyPath,
13+ keypath: "KeyPath", # noqa: F821
14 i: Optional[int] = None,
15 ) -> Optional[str]:
16+ """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
17+ from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
18+
19 trace_vr = (
20 range_constraints[symint.node.expr]
21 if not is_int(symint)
22 else ValueRanges(int(symint), int(symint))
23 )
24-
25 # warn on 0/1 specialization for Dim.AUTO; not an actual error
26- if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
27- pathstr = f"inputs{pytree.keystr(keypath)}"
28- if i is not None:
29- pathstr += f".shape[{i}]"
30- msg = (
31- f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
32- + f"with a sample input with hint = {hint}."
33- )
34- log.warning(msg)
35+ # PATCHED: remove logging
36+ # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
37+ # pathstr = f"inputs{pytree.keystr(keypath)}"
38+ # if i is not None:
39+ # pathstr += f".shape[{i}]"
40+ # msg = (
41+ # f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
42+ # f"with a sample input with hint = {hint}."
43+ # )
44+ # log.warning(msg)
45
46 try:
47 user_vr = ValueRanges(
48@@ -38,32 +41,40 @@
49
50 # check for Dim.DYNAMIC specializations; special case error message on 0/1
51 if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
52- path = f"inputs{pytree.keystr(keypath)}"
53+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
54 if i is not None:
55 path += f".shape[{i}]"
56 if (
57 trace_vr.is_singleton()
58 and hint in (0, 1)
59- and not torch.fx.experimental._config.backed_size_oblivious
60+ # PATCHED: line removed
61+ # and not torch.fx.experimental._config.backed_size_oblivious
62 ):
63- msg = (
64- f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
65- f"but export 0/1 specialized due to hint of {hint} for dimension {path}."
66- )
67+ return None
68+ # PATCHED: line removed
69+ # msg = (
70+ # f"- Received user-specified dim hint "
71+ # f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
72+ # f"but export 0/1 specialized due to hint of "
73+ # f"{hint} for dimension {path}."
74+ # )
75 else:
76 msg = (
77- f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
78- f"but tracing inferred a static shape of {out_vr.lower} for dimension {path}."
79+ f"- Received user-specified dim hint "
80+ f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
81+ f"but tracing inferred a static shape of "
82+ f"{out_vr.lower} for dimension {path}."
83 )
84 return msg
85
86 except torch.utils._sympy.value_ranges.ValueRangeError:
87- path = f"inputs{pytree.keystr(keypath)}"
88+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
89 if i is not None:
90 path += f".shape[{i}]"
91 msg = (
92 f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
93- f"conflicting with the inferred min/max range of [{trace_vr.lower}, {trace_vr.upper}], "
94+ f"conflicting with the inferred min/max range of "
95+ f"[{trace_vr.lower}, {trace_vr.upper}], "
96 f"for {path}."
97 )
98 return msg
torch: _broadcast_in_dim_meta -> patched__broadcast_in_dim_meta
1--- original
2+++ rewritten
3@@ -1,6 +1,9 @@
4-def _broadcast_in_dim_meta(
5- a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
6+def patched__broadcast_in_dim_meta(
7+ a: torch._prims_common.TensorLikeType,
8+ shape: torch._prims_common.ShapeType,
9+ broadcast_dimensions: Sequence[int],
10 ):
11+ """Patches ``torch._prims._broadcast_in_dim_meta``."""
12 from torch.fx.experimental.symbolic_shapes import (
13 guard_or_false,
14 guard_or_true,
15@@ -8,7 +11,7 @@
16 )
17
18 # Type checks
19- assert isinstance(a, TensorLike)
20+ assert isinstance(a, torch._prims_common.TensorLike)
21 assert isinstance(shape, Sequence)
22 assert isinstance(broadcast_dimensions, Sequence)
23
24@@ -22,7 +25,7 @@
25 # (no relative reordering of dims) of integers and
26 # each dimension must be within the new shape
27 def _greater_than_reduce(acc, x):
28- assert isinstance(x, Dim)
29+ assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
30 assert x > acc
31 assert x < len(shape)
32
33@@ -34,7 +37,9 @@
34 for idx, new_idx in enumerate(broadcast_dimensions):
35 torch._check(
36 sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
37- lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
38+ lambda idx=idx, new_idx=new_idx: (
39+ f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
40+ ),
41 )
42
43 new_strides = []
44@@ -48,10 +53,26 @@
45 new_strides.append(a.stride()[original_idx])
46 else:
47 new_strides.append(0)
48+ # PATCHED: disabled this check
49+ elif guard_or_false(a.shape[original_idx] != 1):
50+ new_strides.append(a.stride()[original_idx])
51 else:
52+ # This checks generates the following issue:
53+ # non-broadcasting semantics require s3 == Max(s10, s3), False,
54+ # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
55+ # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
56+ # original_idx=1
57 torch._check(
58 a.shape[original_idx] == shape[idx],
59- lambda: f"non-broadcasting semantics require {a.shape[original_idx]} == {shape[idx]}",
60+ lambda idx=idx, original_idx=original_idx: (
61+ f"non-broadcasting semantics require "
62+ f"{a.shape[original_idx]} == {shape[idx]}, "
63+ f"{guard_or_false(a.shape[idx] != 1)}, "
64+ f"guard_or_false(a.shape[idx]==1)="
65+ f"{guard_or_false(a.shape[idx] == 1)}, "
66+ f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
67+ f"shape={shape}, original_idx={original_idx}"
68+ ),
69 )
70 new_strides.append(a.stride()[original_idx])
71 original_idx = original_idx + 1
torch: _maybe_broadcast -> patched__maybe_broadcast
1--- original
2+++ rewritten
3@@ -1,6 +1,9 @@
4-def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
5+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
6+ """Patches ``torch._refs._maybe_broadcast``."""
7+ from torch._prims_common import ShapeType, TensorLike, Number
8+
9 # Computes common shape
10- common_shape = _broadcast_shapes(
11+ common_shape = patched__broadcast_shapes(
12 *(t.shape if isinstance(t, TensorLike) else None for t in args)
13 )
14
15@@ -29,10 +32,15 @@
16 return True
17
18 # u0==u1 assume the same, no broadcasting!
19- torch._check(
20- x == y,
21- lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
22- )
23+ # PATCHED: avoid errors
24+ return True # guard_or_true(x != y)
25+ # torch._check(
26+ # x == y,
27+ # lambda x=x, y=y: (
28+ # f"sizes assumed to be the same due to unbacked "
29+ # f"broadcasting semantics x={x!r}, y={y!r}"
30+ # ),
31+ # )
32
33 return False
34
35@@ -42,7 +50,7 @@
36 elif isinstance(x, Number):
37 return x
38 elif isinstance(x, TensorLike):
39- if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
40+ if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
41 return x
42
43 if should_expand(x.shape, common_shape):
44@@ -50,6 +58,6 @@
45
46 return x
47 else:
48- raise RuntimeError("Unexpected type when broadcasting: " + str(type(x)) + "!")
49+ raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
50
51 return tuple(__maybe_broadcast(x, common_shape) for x in args)
torch: ShapeEnv._evaluate_expr -> patched_ShapeEnv._evaluate_expr
1--- original
2+++ rewritten
3@@ -1,14 +1,24 @@
4 def _evaluate_expr(
5 self,
6- orig_expr: sympy.Basic,
7+ orig_expr: "sympy.Basic", # noqa: F821
8 hint: Optional[Union[bool, int, float]] = None,
9 fx_node: Optional[torch.fx.Node] = None,
10 size_oblivious: bool = False,
11 fallback_value: Optional[bool] = None,
12 *,
13 forcing_spec: bool = False,
14-) -> sympy.Basic:
15+) -> "sympy.Basic": # noqa: F821
16 # TODO: split conjunctions and evaluate them separately
17+ import sympy
18+ from torch.fx.experimental import _config as config
19+ from torch.fx.experimental.symbolic_shapes import (
20+ SympyBoolean,
21+ log,
22+ SymT,
23+ symbol_is_type,
24+ )
25+ from torch._guards import ShapeGuard
26+
27 if isinstance(
28 orig_expr,
29 (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
30@@ -118,7 +128,8 @@
31 self._log_suppressed_dde(orig_expr, fallback_value)
32 return fallback_value
33
34- # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type.
35+ # oblivious_var_to_val will be defined iff we have sizes
36+ # with DimDynamic.OBLIVIOUS_SIZE type.
37 # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
38 if (
39 self.oblivious_var_to_val
40@@ -145,7 +156,8 @@
41 ok = True
42
43 # unbacked_var_to_val is not None iff propagate_real_tensors is on.
44- # if propagate_real_tensors is on, we check the example values to generate (unsound_result)
45+ # if propagate_real_tensors is on, we check the example values
46+ # to generate (unsound_result)
47 # and if they pass we add a runtime assertions and continue.
48 if (
49 not ok
50@@ -163,19 +175,22 @@
51 concrete_val = unsound_result
52 ok = True
53
54- # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion
55+ # Check if this is coming from a python assert statement,
56+ # if so, convert it to a runtime assertion
57 # instead of failing.
58 if not ok and self.trace_asserts and self._is_python_assert():
59 concrete_val = sympy.true
60 transmute_into_runtime_assert = True
61 ok = True
62
63- if not ok:
64- raise self._make_data_dependent_error(
65- expr.xreplace(self.var_to_val),
66- expr,
67- expr_sym_node_id=self._expr_sym_node_id,
68- )
69+ # PATCHED: ok -> True
70+ ok = True
71+ # if not ok:
72+ # raise self._make_data_dependent_error(
73+ # expr.xreplace(self.var_to_val),
74+ # expr,
75+ # expr_sym_node_id=self._expr_sym_node_id,
76+ # )
77 else:
78 expr = new_expr
patch_transformers: dynamic_rope_update -> patched_dynamic_rope_update
1--- original
2+++ rewritten
3@@ -1,99 +1,193 @@
4-def dynamic_rope_update(rope_forward):
5- """
6- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
7- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
8+def patched_dynamic_rope_update(rope_forward):
9+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
10
11- Args:
12- rope_forward (Callable):
13- The forward pass of the RoPE implementation.
14+ ``rope_type`` is determined in the constructor of class
15+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
16
17- Returns:
18- The decorated forward pass.
19+ .. code-block:: python
20+
21+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
22+ self.rope_type = config.rope_scaling.get(
23+ "rope_type", config.rope_scaling.get("type"))
24+ else:
25+ self.rope_type = "default"
26+
27+ The original code of the patched function:
28+
29+ .. code-block:: python
30+
31+ def dynamic_rope_update(rope_forward):
32+ def longrope_frequency_update(self, position_ids, device):
33+ seq_len = torch.max(position_ids) + 1
34+ if hasattr(self.config, "original_max_position_embeddings"):
35+ original_max_position_embeddings =
36+ self.config.original_max_position_embeddings
37+ else:
38+ original_max_position_embeddings =
39+ self.config.max_position_embeddings
40+ if seq_len > original_max_position_embeddings:
41+ if not hasattr(self, "long_inv_freq"):
42+ self.long_inv_freq, _ = self.rope_init_fn(
43+ self.config, device, seq_len=original_max_position_embeddings + 1
44+ )
45+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
46+ else:
47+ self.original_inv_freq = self.original_inv_freq.to(device)
48+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
49+
50+ def dynamic_frequency_update(self, position_ids, device):
51+ seq_len = torch.max(position_ids) + 1
52+ if seq_len > self.max_seq_len_cached: # growth
53+ inv_freq, self.attention_scaling = self.rope_init_fn(
54+ self.config, device, seq_len=seq_len)
55+ self.register_buffer("inv_freq", inv_freq, persistent=False)
56+ self.max_seq_len_cached = seq_len
57+
58+ if seq_len < self.original_max_seq_len and
59+ self.max_seq_len_cached > self.original_max_seq_len:
60+ self.original_inv_freq = self.original_inv_freq.to(device)
61+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
62+ self.max_seq_len_cached = self.original_max_seq_len
63+
64+ @wraps(rope_forward)
65+ def wrapper(self, x, position_ids):
66+ if "dynamic" in self.rope_type:
67+ dynamic_frequency_update(self, position_ids, device=x.device)
68+ elif self.rope_type == "longrope":
69+ longrope_frequency_update(self, position_ids, device=x.device)
70+ return rope_forward(self, x, position_ids)
71+
72+ return wrapper
73+
74 """
75
76 def longrope_frequency_update(self, position_ids, device, layer_type=None):
77- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
78+ # It is no use to patch the function after the model is created
79+ # as rope_init_fn is an attribute set to one function when the model
80+ # is created and when no patch is applied yet.
81+ # So we select the patched version here.
82+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
83 seq_len = torch.max(position_ids) + 1
84- original_max_position_embeddings = getattr(
85- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
86- )
87+ if hasattr(self.config, "original_max_position_embeddings"):
88+ original_max_position_embeddings = self.config.original_max_position_embeddings
89+ else:
90+ original_max_position_embeddings = self.config.max_position_embeddings
91+
92 if layer_type is None:
93- rope_type = self.rope_type
94+ # rope_type = self.rope_type
95 original_inv_freq = self.original_inv_freq
96 prefix = ""
97 else:
98- rope_type = self.rope_type[layer_type]
99+ # rope_type = self.rope_type[layer_type]
100 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
101 prefix = f"{layer_type}_"
102
103- if seq_len > original_max_position_embeddings:
104- if not hasattr(self, f"{layer_type}_long_inv_freq"):
105- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
106- long_inv_freq, _ = rope_init_fn(
107- self.config,
108- device,
109- seq_len=original_max_position_embeddings + 1,
110- layer_type=layer_type,
111- )
112- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
113- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
114- else:
115- # This .to() is needed if the model has been moved to a device after being initialized (because
116- # the buffer is automatically moved, but not the original copy)
117- original_inv_freq = original_inv_freq.to(device)
118- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
119- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
120+ # At export time, seq_len is unknown.
121+ long_inv_freq, _ = rope_init_fn(
122+ self.config, device, seq_len=original_max_position_embeddings + 1
123+ )
124+ original_inv_freq = self.original_inv_freq.to(device)
125+
126+ # PATCHED: uses torch.cond instead of a test
127+ cond = (seq_len > original_max_position_embeddings).item()
128+ inv_freq = torch.cond(
129+ cond,
130+ (lambda x, y: x.clone()),
131+ (lambda x, y: y.clone()),
132+ [long_inv_freq, original_inv_freq],
133+ )
134+ setattr(self, f"{prefix}inv_freq", inv_freq)
135+ # if seq_len > original_max_position_embeddings:
136+ # self.inv_freq = self.long_inv_freq
137+ # else:
138+ # self.inv_freq = self.original_inv_freq
139
140 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
141- """
142- dynamic RoPE layers should recompute `inv_freq` in the following situations:
143- 1 - growing beyond the cached sequence length (allow scaling)
144- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
145- """
146+ # constructor:
147+ # - self.max_seq_len_cached = config.max_position_embeddings
148+ # - self.original_max_seq_len = config.max_position_embeddings
149+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
150+
151+ # It is no use to patch the function after the model is created
152+ # as rope_init_fn is an attribute set to one function when the model
153+ # is created and when no patch is applied yet.
154+ # So we select the patched version here.
155+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
156+
157+ # This behaviour is difficult to translate.
158+ # The sequence always grows.
159+ # The test should always True.
160+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
161+ #
162+ # if seq_len > self.max_seq_len_cached: # growth
163+ # inv_freq, self.attention_scaling = self.rope_init_fn(
164+ # self.config, device, seq_len=seq_len
165+ # )
166+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
167+ # self.max_seq_len_cached = seq_len
168+ #
169+ # So we should not need what follows.
170+ #
171+ # cond = (seq_len > self.max_seq_len_cached).item()
172+ # self.attention_scaling = torch.cond(
173+ # cond,
174+ # (lambda x, y: x.clone()),
175+ # (lambda x, y: y.clone()),
176+ # [attention_scaling, self.attention_scaling],
177+ # )
178+
179 seq_len = torch.max(position_ids) + 1
180+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
181+
182 if layer_type is None:
183- rope_type = self.rope_type
184- max_seq_len_cached = self.max_seq_len_cached
185+ # rope_type = self.rope_type
186+ # max_seq_len_cached = self.max_seq_len_cached
187 original_inv_freq = self.original_inv_freq
188 prefix = ""
189 else:
190- rope_type = self.rope_type[layer_type]
191- max_seq_len_cached = getattr(
192- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
193- )
194+ # rope_type = self.rope_type[layer_type]
195+ # max_seq_len_cached = getattr(
196+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
197+ # )
198 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
199 prefix = f"{layer_type}_"
200
201- if seq_len > max_seq_len_cached: # growth
202- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
203- inv_freq, self.attention_scaling = rope_init_fn(
204- self.config,
205- device,
206- seq_len=seq_len,
207- layer_type=layer_type,
208- )
209- # TODO joao: may break with compilation
210- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
211- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
212+ # Second test to translate.
213+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
214+ # But in that case the following condition is a way to restore the original cache.
215
216- if (
217- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
218- ): # reset
219- # This .to() is needed if the model has been moved to a device after being initialized (because
220- # the buffer is automatically moved, but not the original copy)
221- original_inv_freq = original_inv_freq.to(device)
222- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
223- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
224- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
225+ # if (
226+ # seq_len < self.original_max_seq_len
227+ # and self.max_seq_len_cached > self.original_max_seq_len
228+ # ):
229+ # self.original_inv_freq = self.original_inv_freq.to(device)
230+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
231+ # self.max_seq_len_cached = self.original_max_seq_len
232+
233+ original_inv_freq = self.original_inv_freq.to(device)
234+ cond = (seq_len >= self.original_max_seq_len).item()
235+ # PATCHED: uses torch.cond instead of a test
236+ inv_freq = torch.cond(
237+ cond,
238+ (lambda x, y: x.clone()),
239+ (lambda x, y: y.clone()),
240+ [long_inv_freq, original_inv_freq],
241+ )
242+ setattr(self, f"{prefix}inv_freq", inv_freq)
243
244 @wraps(rope_forward)
245 def wrapper(self, x, position_ids, layer_type=None):
246- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
247- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
248- if "dynamic" in rope_type:
249- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
250- elif rope_type == "longrope":
251- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
252- return rope_forward(self, x, position_ids, **kwargs)
253+ if layer_type is None:
254+ if "dynamic" in self.rope_type:
255+ dynamic_frequency_update(self, position_ids, device=x.device)
256+ elif self.rope_type == "longrope":
257+ longrope_frequency_update(self, position_ids, device=x.device)
258+ return rope_forward(self, x, position_ids)
259+
260+ if "dynamic" in self.rope_type:
261+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
262+ elif self.rope_type == "longrope":
263+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
264+ return rope_forward(self, x, position_ids, layer_type=layer_type)
265
266 return wrapper
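The patch above replaces data-dependent Python tests with torch.cond so that both branches stay in the exported graph. A standalone sketch of that pattern, not taken from the patch itself:

import torch


class PickFrequencies(torch.nn.Module):
    "Keeps a data-dependent choice between two tensors in the exported graph."

    def forward(self, seq_len, long_inv_freq, original_inv_freq):
        # seq_len is a 0-d tensor; the comparison yields a one-element boolean
        # tensor that torch.cond accepts as predicate (the patch calls .item() first)
        return torch.cond(
            seq_len > 16,
            lambda x, y: x.clone(),
            lambda x, y: y.clone(),
            [long_inv_freq, original_inv_freq],
        )


ep = torch.export.export(
    PickFrequencies(),
    (torch.tensor(32), torch.randn(8), torch.randn(8)),
)
print(ep.graph)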
transformers: _vmap_for_bhqkv -> patched__vmap_for_bhqkv
1--- original
2+++ rewritten
3@@ -1,25 +1,50 @@
4-def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
5- """
6- Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
7- the batch and head indices as well if `bh_indices=True`.
8- Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
9- functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
10+def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
11+ """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
12+ from ...helpers import string_type
13
14- Args:
15- mask_function (`Callable`):
16- The mask_function to vmap.
17- bh_indices (`bool`, optional):
18- Whether to vmap over the batch and head indices as well, or only q and kv indices.
19+ dimensions: List[Tuple[Optional[int], ...]] = [
20+ (None, None, None, 0),
21+ (None, None, 0, None),
22+ ]
23+ if bh_indices:
24+ dimensions.extend([(None, 0, None, None), (0, None, None, None)])
25+ # reshape
26+ dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
27+ dimensions = tuple(reversed(dimensions))
28+ indices = tuple(shape.index(-1) for shape in dimensions)
29
30- Returns:
31- Callable: The vmapped function.
32- """
33- # We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
34- dimensions = [(None, None, None, 0), (None, None, 0, None)]
35- if bh_indices:
36- # We extend broadcasting over the [batch_idx, head_idx] dimensions
37- dimensions.extend([(None, 0, None, None), (0, None, None, None)])
38+ # unsqueeze
39+ udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
40
41- for dims in dimensions:
42- mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
43- return mask_function
44+ def vector_mask_function(
45+ *args, mask_function=mask_function, dimensions=dimensions, indices=indices
46+ ):
47+ assert len(args) == len(dimensions) == len(udimensions), (
48+ f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
49+ f"and udimensions={udimensions}."
50+ )
51+ assert len(indices) == len(args), (
52+ f"Mismatch between args={string_type(args)} and indices={indices}, "
53+ f"they should have the same length."
54+ )
55+ for a in args:
56+ assert (
57+ a.ndim == 1
58+ ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
59+ torch._check(a.shape[0] > 0)
60+
61+ new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
62+ # new_args = [
63+ # a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
64+ # for a, dims in zip(args, udimensions)
65+ # ]
66+ max_shape = tuple(args[i].shape[0] for i in indices)
67+ # if _is_torchdynamo_exporting():
68+ # for a in args:
69+ # # The exporter should export with a dimension > 1
70+ # # to make sure it is dynamic.
71+ # torch._check(a.shape[0] > 1)
72+ expanded_args = [a.expand(max_shape) for a in new_args]
73+ return mask_function(*expanded_args)
74+
75+ return vector_mask_function
transformers: sdpa_mask_recent_torch -> patched_sdpa_mask_recent_torch
1--- original
2+++ rewritten
3@@ -1,4 +1,4 @@
4-def sdpa_mask_recent_torch(
5+def patched_sdpa_mask_recent_torch(
6 batch_size: int,
7 cache_position: torch.Tensor,
8 kv_length: int,
9@@ -10,145 +10,42 @@
10 allow_is_bidirectional_skip: bool = False,
11 **kwargs,
12 ) -> Optional[torch.Tensor]:
13- """
14- Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
15- the element should take part in the attention computation, and False that it should not.
16- This function can only be used with torch>=2.5, as the context manager is otherwise not available.
17-
18- Args:
19- batch_size (`int`):
20- The batch size of the input sequence.
21- cache_position (`torch.Tensor`):
22- A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
23- kv_length (`int`):
24- The size that the key and value states will have during the attention computation.
25- kv_offset (`int`, optional):
26- An optional offset to indicate at which first position the key and values states will refer to.
27- mask_function (`Callable`):
28- The mask factory function describing the mask pattern.
29- attention_mask (`torch.Tensor`, optional):
30- The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
31- local_size (`int`, optional):
32- The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
33- to try to skip mask creation if possible.
34- allow_is_causal_skip (`bool`, optional):
35- Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
36- `torch.sdpa` instead. Default to `True`.
37- allow_torch_fix (`bool`, optional):
38- Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
39- versions. We need an arg to skip it when using eager. By default `True`.
40- allow_is_bidirectional_skip (`bool`, optional):
41- Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
42- i.e. full attention without any padding. Default to `False`.
43-
44-
45- ## Creating a simple causal mask:
46-
47- To create the following causal mask:
48-
49- 0 ■ ⬚ ⬚ ⬚ ⬚
50- 1 ■ ■ ⬚ ⬚ ⬚
51- 2 ■ ■ ■ ⬚ ⬚
52- 3 ■ ■ ■ ■ ⬚
53- 4 ■ ■ ■ ■ ■
54-
55- You can do
56-
57- ```python
58- >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
59- >>> tensor([[[[ True, False, False, False, False],
60- [ True, True, False, False, False],
61- [ True, True, True, False, False],
62- [ True, True, True, True, False],
63- [ True, True, True, True, True]]]])
64- ```
65-
66- ## Creating a sliding window mask:
67-
68- To create the following sliding window mask (`sliding_window=3`):
69-
70- 0 ■ ⬚ ⬚ ⬚ ⬚
71- 1 ■ ■ ⬚ ⬚ ⬚
72- 2 ■ ■ ■ ⬚ ⬚
73- 3 ⬚ ■ ■ ■ ⬚
74- 4 ⬚ ⬚ ■ ■ ■
75-
76- You can do
77-
78- ```python
79- >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
80- >>> tensor([[[[ True, False, False, False, False],
81- [ True, True, False, False, False],
82- [ True, True, True, False, False],
83- [False, True, True, True, False],
84- [False, False, True, True, True]]]])
85- ```
86-
87- ## Creating a chunked attention mask
88-
89- To create the following chunked attention mask (`chunk_size=3`):
90-
91- 0 ■ ⬚ ⬚ ⬚ ⬚
92- 1 ■ ■ ⬚ ⬚ ⬚
93- 2 ■ ■ ■ ⬚ ⬚
94- 3 ⬚ ⬚ ⬚ ■ ⬚
95- 4 ⬚ ⬚ ⬚ ■ ■
96-
97- You can do
98-
99- ```python
100- >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
101- >>> tensor([[[[ True, False, False, False, False],
102- [ True, True, False, False, False],
103- [ True, True, True, False, False],
104- [False, False, False, True, False],
105- [False, False, False, True, True]]]])
106- ```
107-
108- """
109+ """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
110 q_length = cache_position.shape[0]
111- # Potentially pad the 2D mask, and slice it correctly
112 padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
113-
114- # Under specific conditions, we can avoid materializing the mask
115- # 1. Causal masks can rely on the `is_causal` argument
116- # 2. Bidirectional do not need any further processing (no bias)
117 if allow_is_causal_skip and _ignore_causal_mask_sdpa(
118 padding_mask, q_length, kv_length, kv_offset, local_size
119 ):
120 return None
121- if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
122+ if (
123+ allow_is_bidirectional_skip
124+ and _ignore_bidirectional_mask_sdpa
125+ and _ignore_bidirectional_mask_sdpa(padding_mask)
126+ ):
127 return None
128
129- # vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
130- # padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
131 if mask_function is bidirectional_mask_function:
132 if padding_mask is not None:
133 # used for slicing without data-dependent slicing
134 mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
135 return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
136- else:
137- return torch.ones(
138- batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device
139- )
140+ return torch.ones(
141+ batch_size,
142+ 1,
143+ q_length,
144+ kv_length,
145+ dtype=torch.bool,
146+ device=cache_position.device,
147+ )
148
149- # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
150- # but without data-dependent slicing (i.e. torch.compile friendly)
151 kv_arange = torch.arange(kv_length, device=cache_position.device)
152 kv_arange += kv_offset
153-
154- # Potentially add the padding 2D mask
155 if padding_mask is not None:
156 mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
157-
158 batch_arange = torch.arange(batch_size, device=cache_position.device)
159 head_arange = torch.arange(1, device=cache_position.device)
160- # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
161- # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
162- # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
163- with TransformGetItemToIndex():
164- causal_mask = _vmap_for_bhqkv(mask_function)(
165- batch_arange, head_arange, cache_position, kv_arange
166- )
167-
168+ # PATCHED: this line calls the patched version of vmap_for_bhqkv
169+ causal_mask = patched__vmap_for_bhqkv(mask_function)(
170+ batch_arange, head_arange, cache_position, kv_arange
171+ )
172 return causal_mask
transformers: sdpa_mask_recent_torch -> patched_sdpa_mask_recent_torch
1--- original
2+++ rewritten
3@@ -1,4 +1,4 @@
4-def sdpa_mask_recent_torch(
5+def patched_sdpa_mask_recent_torch(
6 batch_size: int,
7 cache_position: torch.Tensor,
8 kv_length: int,
9@@ -10,145 +10,42 @@
10 allow_is_bidirectional_skip: bool = False,
11 **kwargs,
12 ) -> Optional[torch.Tensor]:
13- """
14- Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
15- the element should take part in the attention computation, and False that it should not.
16- This function can only be used with torch>=2.5, as the context manager is otherwise not available.
17-
18- Args:
19- batch_size (`int`):
20- The batch size of the input sequence.
21- cache_position (`torch.Tensor`):
22- A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
23- kv_length (`int`):
24- The size that the key and value states will have during the attention computation.
25- kv_offset (`int`, optional):
26- An optional offset to indicate at which first position the key and values states will refer to.
27- mask_function (`Callable`):
28- The mask factory function describing the mask pattern.
29- attention_mask (`torch.Tensor`, optional):
30- The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
31- local_size (`int`, optional):
32- The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
33- to try to skip mask creation if possible.
34- allow_is_causal_skip (`bool`, optional):
35- Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
36- `torch.sdpa` instead. Default to `True`.
37- allow_torch_fix (`bool`, optional):
38- Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
39- versions. We need an arg to skip it when using eager. By default `True`.
40- allow_is_bidirectional_skip (`bool`, optional):
41- Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
42- i.e. full attention without any padding. Default to `False`.
43-
44-
45- ## Creating a simple causal mask:
46-
47- To create the following causal mask:
48-
49- 0 ■ ⬚ ⬚ ⬚ ⬚
50- 1 ■ ■ ⬚ ⬚ ⬚
51- 2 ■ ■ ■ ⬚ ⬚
52- 3 ■ ■ ■ ■ ⬚
53- 4 ■ ■ ■ ■ ■
54-
55- You can do
56-
57- ```python
58- >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
59- >>> tensor([[[[ True, False, False, False, False],
60- [ True, True, False, False, False],
61- [ True, True, True, False, False],
62- [ True, True, True, True, False],
63- [ True, True, True, True, True]]]])
64- ```
65-
66- ## Creating a sliding window mask:
67-
68- To create the following sliding window mask (`sliding_window=3`):
69-
70- 0 ■ ⬚ ⬚ ⬚ ⬚
71- 1 ■ ■ ⬚ ⬚ ⬚
72- 2 ■ ■ ■ ⬚ ⬚
73- 3 ⬚ ■ ■ ■ ⬚
74- 4 ⬚ ⬚ ■ ■ ■
75-
76- You can do
77-
78- ```python
79- >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
80- >>> tensor([[[[ True, False, False, False, False],
81- [ True, True, False, False, False],
82- [ True, True, True, False, False],
83- [False, True, True, True, False],
84- [False, False, True, True, True]]]])
85- ```
86-
87- ## Creating a chunked attention mask
88-
89- To create the following chunked attention mask (`chunk_size=3`):
90-
91- 0 ■ ⬚ ⬚ ⬚ ⬚
92- 1 ■ ■ ⬚ ⬚ ⬚
93- 2 ■ ■ ■ ⬚ ⬚
94- 3 ⬚ ⬚ ⬚ ■ ⬚
95- 4 ⬚ ⬚ ⬚ ■ ■
96-
97- You can do
98-
99- ```python
100- >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
101- >>> tensor([[[[ True, False, False, False, False],
102- [ True, True, False, False, False],
103- [ True, True, True, False, False],
104- [False, False, False, True, False],
105- [False, False, False, True, True]]]])
106- ```
107-
108- """
109+ """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
110 q_length = cache_position.shape[0]
111- # Potentially pad the 2D mask, and slice it correctly
112 padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
113-
114- # Under specific conditions, we can avoid materializing the mask
115- # 1. Causal masks can rely on the `is_causal` argument
116- # 2. Bidirectional do not need any further processing (no bias)
117 if allow_is_causal_skip and _ignore_causal_mask_sdpa(
118 padding_mask, q_length, kv_length, kv_offset, local_size
119 ):
120 return None
121- if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
122+ if (
123+ allow_is_bidirectional_skip
124+ and _ignore_bidirectional_mask_sdpa
125+ and _ignore_bidirectional_mask_sdpa(padding_mask)
126+ ):
127 return None
128
129- # vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
130- # padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
131 if mask_function is bidirectional_mask_function:
132 if padding_mask is not None:
133 # used for slicing without data-dependent slicing
134 mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
135 return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
136- else:
137- return torch.ones(
138- batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device
139- )
140+ return torch.ones(
141+ batch_size,
142+ 1,
143+ q_length,
144+ kv_length,
145+ dtype=torch.bool,
146+ device=cache_position.device,
147+ )
148
149- # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
150- # but without data-dependent slicing (i.e. torch.compile friendly)
151 kv_arange = torch.arange(kv_length, device=cache_position.device)
152 kv_arange += kv_offset
153-
154- # Potentially add the padding 2D mask
155 if padding_mask is not None:
156 mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
157-
158 batch_arange = torch.arange(batch_size, device=cache_position.device)
159 head_arange = torch.arange(1, device=cache_position.device)
160- # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
161- # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
162- # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
163- with TransformGetItemToIndex():
164- causal_mask = _vmap_for_bhqkv(mask_function)(
165- batch_arange, head_arange, cache_position, kv_arange
166- )
167-
168+ # PATCHED: this line calls the patched version of vmap_for_bhqkv
169+ causal_mask = patched__vmap_for_bhqkv(mask_function)(
170+ batch_arange, head_arange, cache_position, kv_arange
171+ )
172 return causal_mask
transformers: eager_mask -> patched_eager_mask
1--- original
2+++ rewritten
3@@ -1,4 +1,4 @@
4-def eager_mask(
5+def patched_eager_mask(
6 batch_size: int,
7 cache_position: torch.Tensor,
8 kv_length: int,
9@@ -8,31 +8,12 @@
10 dtype: torch.dtype = torch.float32,
11 **kwargs,
12 ) -> torch.Tensor:
13- """
14- Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
15- the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
16- it should not.
17-
18- Args:
19- batch_size (`int`):
20- The batch size of the input sequence.
21- cache_position (`torch.Tensor`):
22- A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
23- kv_length (`int`):
24- The size that the key and value states will have during the attention computation.
25- kv_offset (`int`, optional):
26- An optional offset to indicate at which first position the key and values states will refer to.
27- mask_function (`Callable`):
28- The mask factory function describing the mask pattern.
29- attention_mask (`torch.Tensor`, optional):
30- The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
31- dtype (`torch.dtype`, optional):
32- The dtype to use for the mask. By default, `torch.float32`.
33- """
34+ """manual patch for function ``transformers.masking_utils.eager_mask``."""
35 # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
36 _ = kwargs.pop("allow_is_causal_skip", None)
37 _ = kwargs.pop("allow_is_bidirectional_skip", None)
38- mask = sdpa_mask(
39+ # PATCHED: this line called the patched version of sdpa_mask
40+ mask = patched_sdpa_mask_recent_torch(
41 batch_size=batch_size,
42 cache_position=cache_position,
43 kv_length=kv_length,
44@@ -45,6 +26,10 @@
45 **kwargs,
46 )
47 min_dtype = torch.finfo(dtype).min
48- # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
49- mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
50+ # PATCHED: the following line
51+ # we need 0s where the tokens should be taken into account,
52+ # and -inf otherwise (mask is already of boolean type)
53+ # mask =
54+ # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
55+ mask = (~mask).to(dtype) * min_dtype
56 return mask
transformers: sdpa_attention_forward -> patched_sdpa_attention_forward
1--- original
2+++ rewritten
3@@ -1,4 +1,4 @@
4-def sdpa_attention_forward(
5+def patched_sdpa_attention_forward(
6 module: torch.nn.Module,
7 query: torch.Tensor,
8 key: torch.Tensor,
9@@ -9,60 +9,121 @@
10 is_causal: Optional[bool] = None,
11 **kwargs,
12 ) -> tuple[torch.Tensor, None]:
13- if kwargs.get("output_attentions", False):
14- logger.warning_once(
15- "`sdpa` attention does not support `output_attentions=True`."
16- " Please set your attention to `eager` if you want any of these features."
17- )
18+ """
19+ manual patch for function
20+ ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
21+ """
22+ assert not kwargs.get("output_attentions", False), (
23+ "`sdpa` attention does not support `output_attentions=True`."
24+ " Please set your attention to `eager` if you want any of these features."
25+ )
26+ torch._check(
27+ query.shape[0] == key.shape[0] or query.shape[0] == 1,
28+ lambda: (
29+ f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
30+ f"value: {value.shape}"
31+ ),
32+ )
33+ torch._check(
34+ key.shape[0] == value.shape[0] or key.shape[0] == 1,
35+ lambda: (
36+ f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
37+ f"value: {value.shape}"
38+ ),
39+ )
40+
41 sdpa_kwargs = {}
42 if hasattr(module, "num_key_value_groups"):
43- if not use_gqa_in_sdpa(attention_mask, key):
44- key = repeat_kv(key, module.num_key_value_groups)
45- value = repeat_kv(value, module.num_key_value_groups)
46+ if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
47+ key = transformers.integrations.sdpa_attention.repeat_kv(
48+ key, module.num_key_value_groups
49+ )
50+ value = transformers.integrations.sdpa_attention.repeat_kv(
51+ value, module.num_key_value_groups
52+ )
53 else:
54 sdpa_kwargs = {"enable_gqa": True}
55
56 if attention_mask is not None and attention_mask.ndim == 4:
57 attention_mask = attention_mask[:, :, :, : key.shape[-2]]
58
59- # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
60- is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
61+ torch._check(
62+ attention_mask is None or attention_mask.shape[3] == key.shape[2],
63+ lambda: "Attention mask shape incompatible with key shape.",
64+ )
65
66- # SDPA's Flash Attention (and cuDNN) kernels rely on the `is_causal` flag. However, there are certain conditions:
67- # - Not in decoding phase (otherwise we want full attention on the single query token)
68- # - Attention mask is not to be provided (even if it is a causal pattern)
69- # - Internally, we marked this as compatible with causal, i.e. it is a decoder attention type
70- #
71- # Quirks on the conditionals:
72- # - We avoid inline passing this to the SDPA function directly to support both torch.compile's dynamic shapes and
73- # full graph options. Otherwise, dynamic shapes are prevented from compiling.
74- # - It is important to check first for the shape, otherwise compile will fail with
75- # `argument 'is_causal' must be bool, not SymBool`.
76- is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
77+ if patch_sdpa_is_causal:
78+ # transformers>=4.55
79+ is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
80
81- # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
82- # We convert it to a bool for the SDPA kernel that only accepts bools.
83- if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
84- is_causal = is_causal.item()
85+ # PATCHED: remove the test query.shape[2] > 1
86+ # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
87+ # and we split the test to keep the minimum in torch.cond
88+ is_causal = attention_mask is None and is_causal
89
90- # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator,
91- # and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
92- # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
93- if _is_torch_npu_available:
94- if attention_mask is not None and attention_mask.dtype != torch.bool:
95- # Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
96- attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
97+ if not is_causal:
98+ return (
99+ torch.nn.functional.scaled_dot_product_attention(
100+ query,
101+ key,
102+ value,
103+ attn_mask=attention_mask,
104+ dropout_p=dropout,
105+ scale=scaling,
106+ is_causal=is_causal,
107+ **sdpa_kwargs,
108+ )
109+ .transpose(1, 2)
110+ .contiguous(),
111+ None,
112+ )
113+ else:
114+ # transformers<4.55
115+ if is_causal is None and attention_mask is not None:
116+ is_causal = False
117+ if is_causal is not None:
118+ return (
119+ torch.nn.functional.scaled_dot_product_attention(
120+ query,
121+ key,
122+ value,
123+ attn_mask=attention_mask,
124+ dropout_p=dropout,
125+ scale=scaling,
126+ is_causal=is_causal,
127+ **sdpa_kwargs,
128+ )
129+ .transpose(1, 2)
130+ .contiguous(),
131+ None,
132+ )
133
134- attn_output = torch.nn.functional.scaled_dot_product_attention(
135- query,
136- key,
137- value,
138- attn_mask=attention_mask,
139- dropout_p=dropout,
140- scale=scaling,
141- is_causal=is_causal,
142- **sdpa_kwargs,
143+ # To avoid the following errors:
144+ # is_causal=query.shape[2] > 1
145+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
146+ # is_causal=torch.tensor(query.shape[2] > 1)
147+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
148+ attn_output = torch.cond(
149+ query.shape[2] > 1, # distinction between prefill and decoding steps
150+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
151+ query,
152+ key,
153+ value,
154+ dropout_p=dropout,
155+ scale=scaling,
156+ is_causal=True,
157+ **sdpa_kwargs,
158+ ),
159+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
160+ query,
161+ key,
162+ value,
163+ dropout_p=dropout,
164+ scale=scaling,
165+ is_causal=False,
166+ **sdpa_kwargs,
167+ ),
168+ [query, key, value],
169 )
170 attn_output = attn_output.transpose(1, 2).contiguous()
171-
172 return attn_output, None
auto/patch_transformers: DynamicLayer.lazy_initialization -> patched_DynamicLayer.lazy_initialization
1--- original
2+++ rewritten
3@@ -1,5 +1,9 @@
4 def lazy_initialization(self, key_states: torch.Tensor):
5 self.dtype, self.device = key_states.dtype, key_states.device
6- self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
7- self.values = torch.tensor([], dtype=self.dtype, device=self.device)
8- self.is_initialized = True
9+ new_shape = list(key_states.shape)
10+ new_shape[-2] = 0
11+ # PATCHED: used a tensor with an empty shape and not en empty list to initialize
12+ self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
13+ self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
14+ if patch_is_initialized:
15+ self.is_initialized = True
auto/patch_transformers: Gemma2RotaryEmbedding.forward -> common_RotaryEmbedding.forward
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
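The recurring idea in this rewrite is to replace a data-dependent if over the sequence length with torch.cond: the predicate is reduced to a scalar with .item() and both branches share the same signature and return clones. A minimal sketch of that pattern follows; the frequency tensors and lengths below are placeholders, not the real RoPE values.

import torch

long_inv_freq = torch.rand(8)      # frequencies recomputed for a longer sequence (placeholder)
original_inv_freq = torch.rand(8)  # frequencies cached at construction time (placeholder)
seq_len = torch.tensor(4096)
original_max_position_embeddings = 2048

cond = (seq_len > original_max_position_embeddings).item()  # scalar predicate
inv_freq = torch.cond(
    cond,
    (lambda x, y: x.clone()),  # take the recomputed frequencies
    (lambda x, y: y.clone()),  # keep the original ones
    [long_inv_freq, original_inv_freq],
)
print(inv_freq.shape)  # torch.Size([8])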
auto/patch_transformers: Gemma3Model.get_placeholder_mask -> patched_Gemma3Model.get_placeholder_mask¶
1--- original
2+++ rewritten
3@@ -4,14 +4,12 @@
4 inputs_embeds: torch.FloatTensor,
5 image_features: torch.FloatTensor,
6 ):
7- """
8- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
9- equal to the length of multimodal features. If the lengths are different, an error is raised.
10- """
11 if input_ids is None:
12 special_image_mask = inputs_embeds == self.get_input_embeddings()(
13 torch.tensor(
14- self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
15+ self.config.image_token_id,
16+ dtype=torch.long,
17+ device=inputs_embeds.device,
18 )
19 )
20 special_image_mask = special_image_mask.all(-1)
21@@ -23,8 +21,14 @@
22 special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
23 )
24 n_image_features = image_features.shape[0] * image_features.shape[1]
25- if inputs_embeds[special_image_mask].numel() != image_features.numel():
26- raise ValueError(
27- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
28- )
29+ # PATCHED: torch._check
30+ # if inputs_embeds[special_image_mask].numel() != image_features.numel():
31+ # raise ValueError( ... )
32+ torch._check(
33+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
34+ lambda: (
35+ f"Image features and image tokens do not match: tokens: "
36+ f"{n_image_tokens}, features {n_image_features}"
37+ ),
38+ )
39 return special_image_mask
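Replacing the raise with torch._check keeps the consistency test as a graph-level assertion instead of a Python branch that would fail or specialize during export. A toy illustration of the call; the tensors and the message are placeholders:

import torch

def check_match(selected_embeds: torch.Tensor, image_features: torch.Tensor) -> None:
    # torch._check records the condition as a runtime assert the exporter can keep in the graph
    torch._check(
        selected_embeds.numel() == image_features.numel(),
        lambda: "Image features and image tokens do not match",
    )

check_match(torch.zeros(3, 4), torch.zeros(12))  # passes: 12 elements on both sides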
auto/patch_transformers: Gemma3RotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,13 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6+@patched_dynamic_rope_update
7 def forward(self, x, position_ids, layer_type=None):
8- inv_freq = getattr(self, f"{layer_type}_inv_freq")
9- attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
10+ if layer_type is not None:
11+ # transformers>=5.0
12+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
13+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
14+ else:
15+ # transformers<5.0
16+ inv_freq = self.inv_freq
17+ attention_scaling = self.attention_scaling
18
19 inv_freq_expanded = (
20 inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21
22--- original
23+++ rewritten
24@@ -1,99 +1,193 @@
25-def dynamic_rope_update(rope_forward):
26- """
27- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
28- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
29+def patched_dynamic_rope_update(rope_forward):
30+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
31
32- Args:
33- rope_forward (Callable):
34- The forward pass of the RoPE implementation.
35+ ``rope_type`` is determined in the constructor of class
36+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
37
38- Returns:
39- The decorated forward pass.
40+ .. code-block:: python
41+
42+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
43+ self.rope_type = config.rope_scaling.get(
44+ "rope_type", config.rope_scaling.get("type"))
45+ else:
46+ self.rope_type = "default"
47+
48+ The original code of the patched function:
49+
50+ .. code-block:: python
51+
52+ def dynamic_rope_update(rope_forward):
53+ def longrope_frequency_update(self, position_ids, device):
54+ seq_len = torch.max(position_ids) + 1
55+ if hasattr(self.config, "original_max_position_embeddings"):
56+ original_max_position_embeddings =
57+ self.config.original_max_position_embeddings
58+ else:
59+ original_max_position_embeddings =
60+ self.config.max_position_embeddings
61+ if seq_len > original_max_position_embeddings:
62+ if not hasattr(self, "long_inv_freq"):
63+ self.long_inv_freq, _ = self.rope_init_fn(
64+ self.config, device, seq_len=original_max_position_embeddings + 1
65+ )
66+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
67+ else:
68+ self.original_inv_freq = self.original_inv_freq.to(device)
69+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
70+
71+ def dynamic_frequency_update(self, position_ids, device):
72+ seq_len = torch.max(position_ids) + 1
73+ if seq_len > self.max_seq_len_cached: # growth
74+ inv_freq, self.attention_scaling = self.rope_init_fn(
75+ self.config, device, seq_len=seq_len)
76+ self.register_buffer("inv_freq", inv_freq, persistent=False)
77+ self.max_seq_len_cached = seq_len
78+
79+ if seq_len < self.original_max_seq_len and
80+ self.max_seq_len_cached > self.original_max_seq_len:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+ self.max_seq_len_cached = self.original_max_seq_len
84+
85+ @wraps(rope_forward)
86+ def wrapper(self, x, position_ids):
87+ if "dynamic" in self.rope_type:
88+ dynamic_frequency_update(self, position_ids, device=x.device)
89+ elif self.rope_type == "longrope":
90+ longrope_frequency_update(self, position_ids, device=x.device)
91+ return rope_forward(self, x, position_ids)
92+
93+ return wrapper
94+
95 """
96
97 def longrope_frequency_update(self, position_ids, device, layer_type=None):
98- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
99+ # It is no use to patch the function after the model is created
100+ # as rope_init_fn is an attribute set to one function when the model
101+ # is created and when no patch is applied yet.
102+ # So we select the patched version here.
103+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
104 seq_len = torch.max(position_ids) + 1
105- original_max_position_embeddings = getattr(
106- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
107- )
108+ if hasattr(self.config, "original_max_position_embeddings"):
109+ original_max_position_embeddings = self.config.original_max_position_embeddings
110+ else:
111+ original_max_position_embeddings = self.config.max_position_embeddings
112+
113 if layer_type is None:
114- rope_type = self.rope_type
115+ # rope_type = self.rope_type
116 original_inv_freq = self.original_inv_freq
117 prefix = ""
118 else:
119- rope_type = self.rope_type[layer_type]
120+ # rope_type = self.rope_type[layer_type]
121 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
122 prefix = f"{layer_type}_"
123
124- if seq_len > original_max_position_embeddings:
125- if not hasattr(self, f"{layer_type}_long_inv_freq"):
126- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
127- long_inv_freq, _ = rope_init_fn(
128- self.config,
129- device,
130- seq_len=original_max_position_embeddings + 1,
131- layer_type=layer_type,
132- )
133- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
134- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
135- else:
136- # This .to() is needed if the model has been moved to a device after being initialized (because
137- # the buffer is automatically moved, but not the original copy)
138- original_inv_freq = original_inv_freq.to(device)
139- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
140- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
141+ # At export time, seq_len is unknown.
142+ long_inv_freq, _ = rope_init_fn(
143+ self.config, device, seq_len=original_max_position_embeddings + 1
144+ )
145+ original_inv_freq = self.original_inv_freq.to(device)
146+
147+ # PATCHED: uses torch.cond instead of a test
148+ cond = (seq_len > original_max_position_embeddings).item()
149+ inv_freq = torch.cond(
150+ cond,
151+ (lambda x, y: x.clone()),
152+ (lambda x, y: y.clone()),
153+ [long_inv_freq, original_inv_freq],
154+ )
155+ setattr(self, f"{prefix}inv_freq", inv_freq)
156+ # if seq_len > original_max_position_embeddings:
157+ # self.inv_freq = self.long_inv_freq
158+ # else:
159+ # self.inv_freq = self.original_inv_freq
160
161 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
162- """
163- dynamic RoPE layers should recompute `inv_freq` in the following situations:
164- 1 - growing beyond the cached sequence length (allow scaling)
165- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
166- """
167+ # constructor:
168+ # - self.max_seq_len_cached = config.max_position_embeddings
169+ # - self.original_max_seq_len = config.max_position_embeddings
170+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
171+
172+ # It is no use to patch the function after the model is created
173+ # as rope_init_fn is an attribute set to one function when the model
174+ # is created and when no patch is applied yet.
175+ # So we select the patched version here.
176+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
177+
178+ # This behaviour is difficult to translate.
179+ # The sequence always grows.
180+ # The test should always be True.
181+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
182+ #
183+ # if seq_len > self.max_seq_len_cached: # growth
184+ # inv_freq, self.attention_scaling = self.rope_init_fn(
185+ # self.config, device, seq_len=seq_len
186+ # )
187+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
188+ # self.max_seq_len_cached = seq_len
189+ #
190+ # So we should not need what follows.
191+ #
192+ # cond = (seq_len > self.max_seq_len_cached).item()
193+ # self.attention_scaling = torch.cond(
194+ # cond,
195+ # (lambda x, y: x.clone()),
196+ # (lambda x, y: y.clone()),
197+ # [attention_scaling, self.attention_scaling],
198+ # )
199+
200 seq_len = torch.max(position_ids) + 1
201+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
202+
203 if layer_type is None:
204- rope_type = self.rope_type
205- max_seq_len_cached = self.max_seq_len_cached
206+ # rope_type = self.rope_type
207+ # max_seq_len_cached = self.max_seq_len_cached
208 original_inv_freq = self.original_inv_freq
209 prefix = ""
210 else:
211- rope_type = self.rope_type[layer_type]
212- max_seq_len_cached = getattr(
213- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
214- )
215+ # rope_type = self.rope_type[layer_type]
216+ # max_seq_len_cached = getattr(
217+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
218+ # )
219 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
220 prefix = f"{layer_type}_"
221
222- if seq_len > max_seq_len_cached: # growth
223- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
224- inv_freq, self.attention_scaling = rope_init_fn(
225- self.config,
226- device,
227- seq_len=seq_len,
228- layer_type=layer_type,
229- )
230- # TODO joao: may break with compilation
231- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
232- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
233+ # Second test to translate.
234+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
235+ # But in that case the following condition is a way to restore the original cache.
236
237- if (
238- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
239- ): # reset
240- # This .to() is needed if the model has been moved to a device after being initialized (because
241- # the buffer is automatically moved, but not the original copy)
242- original_inv_freq = original_inv_freq.to(device)
243- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
244- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
245- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
246+ # if (
247+ # seq_len < self.original_max_seq_len
248+ # and self.max_seq_len_cached > self.original_max_seq_len
249+ # ):
250+ # self.original_inv_freq = self.original_inv_freq.to(device)
251+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
252+ # self.max_seq_len_cached = self.original_max_seq_len
253+
254+ original_inv_freq = self.original_inv_freq.to(device)
255+ cond = (seq_len >= self.original_max_seq_len).item()
256+ # PATCHED: uses torch.cond instead of a test
257+ inv_freq = torch.cond(
258+ cond,
259+ (lambda x, y: x.clone()),
260+ (lambda x, y: y.clone()),
261+ [long_inv_freq, original_inv_freq],
262+ )
263+ setattr(self, f"{prefix}inv_freq", inv_freq)
264
265 @wraps(rope_forward)
266 def wrapper(self, x, position_ids, layer_type=None):
267- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
268- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
269- if "dynamic" in rope_type:
270- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
271- elif rope_type == "longrope":
272- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
273- return rope_forward(self, x, position_ids, **kwargs)
274+ if layer_type is None:
275+ if "dynamic" in self.rope_type:
276+ dynamic_frequency_update(self, position_ids, device=x.device)
277+ elif self.rope_type == "longrope":
278+ longrope_frequency_update(self, position_ids, device=x.device)
279+ return rope_forward(self, x, position_ids)
280+
281+ if "dynamic" in self.rope_type:
282+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
283+ elif self.rope_type == "longrope":
284+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
285+ return rope_forward(self, x, position_ids, layer_type=layer_type)
286
287 return wrapper
auto/patch_transformers: GemmaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation -> patched_GenerationMixin._cache_dependant_input_preparation¶
1--- original
2+++ rewritten
3@@ -3,23 +3,29 @@
4 input_ids: torch.LongTensor,
5 inputs_embeds: Optional[torch.FloatTensor],
6 cache_position: Optional[torch.LongTensor],
7-) -> tuple[torch.FloatTensor, torch.LongTensor]:
8+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
9 """
10 Generic cache-dependent input preparation
11 The code is put in a separate function to allow granular unit testing
12 as it needs a different implementation to be exportable.
13
14- If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
15- - Exception 1: when passing input_embeds, input_ids may be missing entries
16- - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
17- - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
18- - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
19- generate the first token for each sequence. Later use the generated Input ids for continuation.
20+ If we have cache: let's slice `input_ids` through `cache_position`,
21+ to keep only the unprocessed tokens
22+ - Exception 1: when passing input_embeds,
23+ input_ids may be missing entries
24+ - Exception 2: some generation methods do special slicing of input_ids,
25+ so we don't need to do it here
26+ - Exception 3: with synced GPUs cache_position may go out of bounds,
27+ but we only want dummy token in that case.
28+ - Exception 4: If input_embeds are passed then slice it through
29+ `cache_position`, to keep only the unprocessed tokens and
30+ generate the first token for each sequence.
31+ Later use the generated Input ids for continuation.
32
33 The current implementation does not rely on ``self`` and could be
34 a class method. It is left as a standard method to be easily rewritten.
35 """
36- if is_torchdynamo_exporting():
37+ if _is_torchdynamo_exporting():
38 return self._cache_dependant_input_preparation_exporting(
39 input_ids, inputs_embeds, cache_position
40 )
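The docstring above states the rule being implemented: when a cache is present, only the tokens that were not processed yet are kept, and cache_position indicates which ones those are. A tiny eager illustration of that rule (token values are arbitrary):

import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16, 17]])  # 7 tokens, 5 already in the cache
cache_position = torch.tensor([5, 6])                      # positions still to process

print(input_ids[:, cache_position])             # tensor([[16, 17]])
print(input_ids[:, -cache_position.shape[0]:])  # tensor([[16, 17]]), the slicing variant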
auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation_exporting -> patched_GenerationMixin._cache_dependant_input_preparation_exporting¶
1--- original
2+++ rewritten
3@@ -3,7 +3,7 @@
4 input_ids: torch.LongTensor,
5 inputs_embeds: Optional[torch.FloatTensor],
6 cache_position: Optional[torch.LongTensor],
7-) -> tuple[torch.FloatTensor, torch.LongTensor]:
8+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
9 """
10 This method implements method ``_cache_dependant_input_preparation``
11 with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
12@@ -21,22 +21,21 @@
13 # else:
14 # if input_ids.shape[1] != cache_position.shape[0]:
15 # input_ids = input_ids[:, cache_position]
16- # We need to clone the outputs to avoid aliasing.
17 def branch_1(inputs_embeds, cache_position):
18- return inputs_embeds[:, -cache_position.shape[0] :].clone()
19+ return inputs_embeds[:, -cache_position.shape[0] :]
20
21 def branch_2(input_ids, cache_position):
22- return input_ids[:, -cache_position.shape[0] :].clone()
23+ return input_ids[:, -cache_position.shape[0] :]
24
25 def branch_3(input_ids, cache_position):
26- return input_ids[:, cache_position].clone()
27+ return input_ids[:, cache_position]
28
29 inputs_embeds, input_ids = torch.cond(
30 input_ids.shape[1] == 0,
31 (
32 lambda input_ids, inputs_embeds, cache_position: (
33 branch_1(inputs_embeds, cache_position),
34- input_ids.clone(),
35+ input_ids,
36 )
37 ),
38 (
39@@ -49,7 +48,7 @@
40 torch.cond(
41 input_ids.shape[1] != cache_position.shape[0],
42 branch_3,
43- (lambda input_ids, cache_position: input_ids.clone()),
44+ (lambda input_ids, cache_position: input_ids),
45 [input_ids, cache_position],
46 )
47 ),
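The exporting variant expresses the same slicing rule with torch.cond so that both outcomes are part of the graph. The sketch below reproduces only the inner pair of branches (gather through cache_position versus keep the tail of input_ids); it is a simplified standalone version, not the method itself.

import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16, 17]])
cache_position = torch.tensor([5, 6])

def gather_by_position(input_ids, cache_position):
    return input_ids[:, cache_position]

def keep_tail(input_ids, cache_position):
    return input_ids[:, -cache_position.shape[0]:]

# both branches return tensors with identical shape and dtype, as torch.cond requires
trimmed = torch.cond(
    input_ids.shape[1] != cache_position.shape[0],
    gather_by_position,
    keep_tail,
    [input_ids, cache_position],
)
print(trimmed)  # tensor([[16, 17]])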
auto/patch_transformers: IdeficsAttention.forward -> patched_IdeficsAttention.forward¶
1--- original
2+++ rewritten
3@@ -4,10 +4,12 @@
4 key_value_states: Optional[torch.Tensor] = None,
5 attention_mask: Optional[torch.Tensor] = None,
6 position_ids: Optional[torch.LongTensor] = None,
7- past_key_values: Optional[Cache] = None,
8+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
9+ output_attentions: bool = False,
10+ use_cache: bool = False,
11 cache_position: Optional[torch.LongTensor] = None,
12- **kwargs: Unpack[TransformersKwargs],
13-) -> tuple[torch.Tensor, torch.Tensor]:
14+ **kwargs,
15+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
16 # if key_value_states are provided this layer is used as a cross-attention layer
17 is_cross_attention = self.is_cross_attention or key_value_states is not None
18
19@@ -43,20 +45,27 @@
20 )
21
22 kv_seq_len = key_states.shape[-2]
23- if past_key_values is not None:
24+ if past_key_value is not None:
25 kv_seq_len += cache_position[0]
26
27 if not is_cross_attention:
28- cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
29- query_states, key_states = apply_rotary_pos_emb(
30- query_states, key_states, cos, sin, position_ids
31+ rotary_length = torch.maximum(
32+ torch.tensor(kv_seq_len, dtype=torch.int64),
33+ torch.tensor(q_len, dtype=torch.int64),
34+ )
35+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
36+ query_states, key_states = (
37+ transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
38+ query_states, key_states, cos, sin, position_ids
39+ )
40 )
41 # [bsz, nh, t, hd]
42
43- if past_key_values is not None:
44- # sin and cos are specific to RoPE models; cache_position needed for the static cache
45+ if past_key_value is not None:
46+ # sin and cos are specific to RoPE models;
47+ # cache_position needed for the static cache
48 cache_kwargs = {"cache_position": cache_position}
49- key_states, value_states = past_key_values.update(
50+ key_states, value_states = past_key_value.update(
51 key_states, value_states, self.layer_idx, cache_kwargs
52 )
53
54@@ -64,10 +73,22 @@
55 query_states = self.q_layer_norm(query_states)
56 key_states = self.k_layer_norm(key_states)
57
58- attention_interface: Callable = eager_attention_forward
59+ attention_interface: Callable = (
60+ transformers.models.idefics.modeling_idefics.eager_attention_forward
61+ )
62
63 if self.config._attn_implementation != "eager":
64- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
65+ if self.config._attn_implementation == "sdpa" and output_attentions:
66+ transformers.models.idefics.modeling_idefics.logger.warning_once(
67+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
68+ "`output_attentions=True`. Falling back to "
69+ "eager attention. This warning can be removed using the argument "
70+ '`attn_implementation="eager"` when loading the model.'
71+ )
72+ else:
73+ attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
74+ self.config._attn_implementation
75+ ]
76
77 attn_output, attn_weights = attention_interface(
78 self,
79@@ -83,4 +104,9 @@
80 attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
81 attn_output = self.o_proj(attn_output)
82
83+ if output_attentions:
84+ attn_weights = None
85+
86+ if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
87+ return attn_output, attn_weights, past_key_value
88 return attn_output, attn_weights
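One export-relevant change above is replacing Python's max(kv_seq_len, q_len) with torch.maximum over scalar tensors: max() needs a concrete boolean at trace time, whereas torch.maximum keeps the comparison inside the graph. A minimal sketch with made-up lengths:

import torch

kv_seq_len, q_len = 9, 7  # in the patch these come from tensor shapes and may be symbolic

rotary_length = torch.maximum(
    torch.tensor(kv_seq_len, dtype=torch.int64),
    torch.tensor(q_len, dtype=torch.int64),
)
print(rotary_length)  # tensor(9)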
auto/patch_transformers: IdeficsEmbedding.forward -> patched_IdeficsEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,9 +1,26 @@
4 def forward(self, x, seq_len=None):
5 # x: [bs, num_attention_heads, seq_len, head_size]
6- if seq_len > self.max_seq_len_cached:
7- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
8+ # if seq_len > self.max_seq_len_cached:
9+ # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
10
11- return (
12- self.cos_cached[:seq_len].to(dtype=x.dtype),
13- self.sin_cached[:seq_len].to(dtype=x.dtype),
14+ def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
15+ t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
16+ # freqs = torch.einsum("i,j->ij", t, inv_freq)
17+ freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
18+ emb = torch.cat((freqs, freqs), dim=-1)
19+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
20+
21+ def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
22+ torch._check(seq_len.item() <= cos_cached.shape[0])
23+ co = cos_cached[: seq_len.item()].detach().clone()
24+ torch._check(seq_len.item() <= sin_cached.shape[0])
25+ si = sin_cached[: seq_len.item()].detach().clone()
26+ return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
27+
28+ cos_cached, sin_cached = torch.cond(
29+ (seq_len > self.max_seq_len_cached).item(),
30+ _set_cos_sin_cache_then,
31+ _set_cos_sin_cache_else,
32+ [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
33 )
34+ return cos_cached, sin_cached
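In the else branch above, the slice length comes from seq_len.item(), a data-dependent value, and torch._check tells the exporter that the slice stays within the cached buffers. A small standalone sketch of that idiom; the buffer sizes are placeholders:

import torch

cos_cached = torch.randn(16, 8)  # cached cosine table (placeholder size)
seq_len = torch.tensor(5)        # data-dependent length, e.g. torch.max(position_ids) + 1

n = seq_len.item()
torch._check(n <= cos_cached.shape[0])  # assert the slice is in bounds for the exporter
cos = cos_cached[:n].detach().clone()
print(cos.shape)  # torch.Size([5, 8])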
auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
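The essential trick in the patch above is replacing the data-dependent test on seq_len with torch.cond, so that torch.export records both branches instead of specializing on one of them. Below is a minimal, self-contained sketch of that pattern; the module, the buffer names and the threshold 16 are illustrative only, and it assumes a recent PyTorch (such as the nightly shown at the top of this page) where the boolean produced by .item() is accepted as the predicate during export.
import torch

class PickFrequencies(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # stand-ins for long_inv_freq / original_inv_freq
        self.register_buffer("long_inv_freq", torch.rand(8))
        self.register_buffer("original_inv_freq", torch.rand(8))

    def forward(self, position_ids):
        seq_len = torch.max(position_ids) + 1
        # a plain ``if`` here would be data-dependent and break the export
        cond = (seq_len >= 16).item()
        return torch.cond(
            cond,
            (lambda x, y: x.clone()),
            (lambda x, y: y.clone()),
            [self.long_inv_freq, self.original_inv_freq],
        )

ep = torch.export.export(
    PickFrequencies(),
    (torch.arange(32).unsqueeze(0),),
    dynamic_shapes=({1: torch.export.Dim("seq")},),
)
print(ep)  # both branches appear as subgraphs of a higher-order cond node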
auto/patch_transformers: MistralRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings = (
70+ self.config.original_max_position_embeddings)
71+ else:
72+ original_max_position_embeddings = (
73+ self.config.max_position_embeddings)
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if (seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len):
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
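The recurring comment about rope_init_fn explains why the patch resolves the init function at call time through _get_rope_init_fn instead of relying on module-level patching: the rotary embedding stores a reference to the function when it is constructed, so replacing the module-level name afterwards does not affect that instance. A tiny illustration of that Python behaviour (all names below are made up):
def original_init(config):
    return "original"

class Holder:
    def __init__(self):
        # the reference is captured at construction time
        self.rope_init_fn = original_init

h = Holder()

def patched_init(config):
    return "patched"

original_init = patched_init  # "patching" the module-level name afterwards

print(h.rope_init_fn(None))   # -> "original": the instance kept the old reference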
auto/patch_transformers: MixtralRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings = (
70+ self.config.original_max_position_embeddings)
71+ else:
72+ original_max_position_embeddings = (
73+ self.config.max_position_embeddings)
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if (seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len):
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
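In the patched forward above, layer_type selects per-layer buffers by name for transformers>=5.0, while the plain attributes are used for transformers<5.0. The sketch below only illustrates that attribute lookup; the layer type name "sliding_attention" is hypothetical.
import torch

class DummyRope(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # transformers<5.0 style: a single buffer and scaling factor
        self.register_buffer("inv_freq", torch.rand(4))
        self.attention_scaling = 1.0
        # transformers>=5.0 style: one set per layer type
        self.register_buffer("sliding_attention_inv_freq", torch.rand(4))
        self.sliding_attention_attention_scaling = 0.5

    def resolve(self, layer_type=None):
        if layer_type is None:
            return self.inv_freq, self.attention_scaling
        return (
            getattr(self, f"{layer_type}_inv_freq"),
            getattr(self, f"{layer_type}_attention_scaling"),
        )

rope = DummyRope()
print(rope.resolve())                     # old naming scheme
print(rope.resolve("sliding_attention"))  # per-layer naming scheme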
auto/patch_transformers: Phi3RotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings = (
70+ self.config.original_max_position_embeddings)
71+ else:
72+ original_max_position_embeddings = (
73+ self.config.max_position_embeddings)
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if (seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len):
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
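Both branches passed to torch.cond in the patch simply clone one of the two precomputed tensors. The clone is not cosmetic: when the cond is captured by the exporter, both branches must take the same operands, must not mutate them, and should not return a tensor that aliases a branch input, so returning x directly can be rejected. A minimal sketch of the selection step in isolation, with placeholder tensors:
import torch

long_inv_freq = torch.rand(8)
original_inv_freq = torch.rand(8)
pred = torch.tensor(True)  # in the patch this comes from a comparison on seq_len

picked = torch.cond(
    pred,
    (lambda x, y: x.clone()),  # clone() avoids returning an alias of a branch input
    (lambda x, y: y.clone()),
    [long_inv_freq, original_inv_freq],
)
assert torch.equal(picked, long_inv_freq)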
auto/patch_transformers: Phi4MultimodalRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings = (
70+ self.config.original_max_position_embeddings)
71+ else:
72+ original_max_position_embeddings = (
73+ self.config.max_position_embeddings)
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if (seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len):
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
auto/patch_transformers: PhiRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings = (
70+ self.config.original_max_position_embeddings)
71+ else:
72+ original_max_position_embeddings = (
73+ self.config.max_position_embeddings)
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if (seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len):
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # Patching the function after the model is created has no effect:
113+ # rope_init_fn is an attribute bound to the original function at model
114+ # creation time, before any patch is applied.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # Patching the function after the model is created has no effect:
186+ # rope_init_fn is an attribute bound to the original function at model
187+ # creation time, before any patch is applied.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Keep in mind that self.max_seq_len_cached == seq_len is likely to hold;
248+ # in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
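The key idea of this patch is worth isolating: both candidate frequency tensors are computed unconditionally and torch.cond merely selects one of them, so the exported graph contains no data-dependent Python branch. A minimal, self-contained sketch of that selection follows; select_inv_freq and the max_pos threshold are hypothetical names standing in for the attributes used above.

import torch

def select_inv_freq(seq_len, long_inv_freq, original_inv_freq, max_pos=4096):
    # max_pos stands in for original_max_position_embeddings (assumption).
    # Both candidates already exist; torch.cond picks one of them instead of
    # branching on a data-dependent condition.
    cond = (seq_len > max_pos).item()
    return torch.cond(
        cond,
        lambda x, y: x.clone(),
        lambda x, y: y.clone(),
        [long_inv_freq, original_inv_freq],
    )

inv_freq = select_inv_freq(torch.tensor(8192), torch.full((64,), 2.0), torch.ones(64))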
auto/patch_transformers: Qwen3MoeSparseMoeBlock.forward -> patched_Qwen3MoeSparseMoeBlock.forward¶
1--- original
2+++ rewritten
3@@ -1,9 +1,67 @@
4-def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
5+def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
6+ """ """
7 batch_size, sequence_length, hidden_dim = hidden_states.shape
8- hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
9- router_logits = self.gate(hidden_states_reshaped)
10- selected_experts, routing_weights = self.route_tokens_to_experts(
11- hidden_states_reshaped, router_logits
12+ hidden_states = hidden_states.view(-1, hidden_dim)
13+ # router_logits: (batch * sequence_length, n_experts)
14+ router_logits = self.gate(hidden_states)
15+
16+ routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
17+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
18+ if self.norm_topk_prob: # only diff with mixtral sparse moe block!
19+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
20+ # we cast back to the input dtype
21+ routing_weights = routing_weights.to(hidden_states.dtype)
22+
23+ final_hidden_states = torch.zeros(
24+ (batch_size * sequence_length, hidden_dim),
25+ dtype=hidden_states.dtype,
26+ device=hidden_states.device,
27 )
28- final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
29- return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
30+
31+ # One hot encode the selected experts to create an expert mask
32+ # this will be used to easily index which expert is going to be solicited
33+ expert_mask = torch.nn.functional.one_hot(
34+ selected_experts, num_classes=self.num_experts
35+ ).permute(2, 1, 0)
36+
37+ # Loop over all available experts in the model
38+ # and perform the computation on each expert
39+ expert_sum = expert_mask.sum(dim=(-1, -2))
40+ # expert_hit = torch.greater(expert_sum, 0).nonzero()
41+ # for expert_idx in expert_hit:
42+ for expert_idx in range(self.num_experts):
43+ # the original code has a squeeze here, but that is not possible in this rewriting.
44+ # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
45+ expert_mask_idx = expert_mask[expert_idx]
46+ final_hidden_states = torch.cond(
47+ (expert_sum[expert_idx] > 0).item(),
48+ lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop( # noqa: E501
49+ final_hidden_states,
50+ expert_mask,
51+ hidden_states,
52+ routing_weights,
53+ expert_idx=_i,
54+ ),
55+ lambda final_hidden_states, *args: final_hidden_states.clone(),
56+ [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
57+ )
58+
59+ # if expert_sum[expert_idx] > 0:
60+ # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
61+
62+ # Index the correct hidden states and compute the expert hidden state for
63+ # the current expert. We need to make sure to multiply the output hidden
64+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
65+ # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
66+ # current_hidden_states = (
67+ # expert_layer(current_state) * routing_weights[top_x, idx, None]
68+ # )
69+
70+ # However `index_add_` only support torch tensors for indexing so we'll use
71+ # the `top_x` tensor here.
72+ # final_hidden_states.index_add_(
73+ # 0, top_x, current_hidden_states.to(hidden_states.dtype)
74+ # )
75+
76+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
77+ return final_hidden_states, router_logits
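Two details of this rewriting deserve a note. The loop runs over every expert with a fixed trip count and guards each body with torch.cond, rather than iterating only over the experts that were actually hit (which would be data-dependent control flow), and the expert index is bound through a default argument (_i=expert_idx) so each lambda captures the current Python int by value. A minimal sketch of the guarded-update pattern, with hypothetical names:

import torch

def add_if_nonempty(acc, hits, values):
    # Guard the accumulation with torch.cond so that export keeps both branches
    # instead of specializing on whether this expert received any token.
    return torch.cond(
        (hits.sum() > 0).item(),
        lambda a, h, v: a + h.unsqueeze(-1) * v,
        lambda a, h, v: a.clone(),
        [acc, hits, values],
    )

acc = torch.zeros(4, 8)
hits = torch.tensor([1.0, 0.0, 1.0, 0.0])
acc = add_if_nonempty(acc, hits, torch.randn(4, 8))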
auto/patch_transformers: ‘Qwen3MoeSparseMoeBlock_forward_expert_loop’ -> patched_Qwen3MoeSparseMoeBlock._forward_expert_loop¶
1def _forward_expert_loop(
2 self,
3 final_hidden_states,
4 expert_mask_idx,
5 hidden_states,
6 routing_weights,
7 expert_idx: int,
8):
9 # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
10 idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
11 hidden_dim = hidden_states.shape[-1]
12 current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
13 expert_current_state = self.experts[expert_idx](current_state)
14 current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
15 return final_hidden_states.index_add(0, top_x, current_hidden_states.to(hidden_states.dtype))
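Note that the helper ends with the out-of-place index_add rather than the in-place index_add_ used by the original transformers code: the branches passed to torch.cond must be functional, so the accumulator is returned as a new tensor instead of being mutated. A small, hypothetical illustration of the difference:

import torch

base = torch.zeros(5, 3)
rows = torch.tensor([0, 2, 2])
updates = torch.ones(3, 3)

# Out-of-place: `base` is left untouched, which keeps a torch.cond branch
# free of side effects; the two updates targeting row 2 are accumulated.
result = base.index_add(0, rows, updates)
assert base.sum().item() == 0.0
assert result[2].sum().item() == 6.0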
auto/patch_transformers: SamMaskDecoder.forward -> patched_SamMaskDecoder.forward¶
1--- original
2+++ rewritten
3@@ -5,6 +5,7 @@
4 sparse_prompt_embeddings: torch.Tensor,
5 dense_prompt_embeddings: torch.Tensor,
6 multimask_output: bool,
7+ output_attentions: Optional[bool] = None,
8 attention_similarity: Optional[torch.Tensor] = None,
9 target_embedding: Optional[torch.Tensor] = None,
10 ) -> tuple[torch.Tensor, torch.Tensor]:
11@@ -22,19 +23,31 @@
12 the embeddings of the mask inputs
13 multimask_output (bool):
14 Whether to return multiple masks or a single mask.
15+ output_attentions (bool, *optional*):
16+ Whether or not to return the attentions tensors of all attention layers.
17 """
18 batch_size, num_channels, height, width = image_embeddings.shape
19- point_batch_size = (
20- sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
21- )
22+ point_batch_size = sparse_prompt_embeddings.shape[1]
23 # Concatenate output tokens
24 output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
25 output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
26
27- if sparse_prompt_embeddings is not None:
28- tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
29- else:
30- tokens = output_tokens
31+ # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings;
32+ # torch.any is used as the predicate to avoid the data-dependent control flow
33+ # that sparse_prompt_embeddings.sum().item() != 0 would introduce
34+ def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
35+ return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
36+
37+ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
38+ return output_tokens.clone()
39+
40+ tokens = torch.cond(
41+ torch.any(sparse_prompt_embeddings != 0),
42+ sparse_prompt_embeddings_is_not_empty,
43+ sparse_prompt_embeddings_is_empty,
44+ [output_tokens, sparse_prompt_embeddings],
45+ )
46+
47 point_embeddings = tokens.to(self.iou_token.weight.dtype)
48
49 # Expand per-image data in batch direction to be per-point
50@@ -45,15 +58,21 @@
51 )
52
53 # Run the transformer, image_positional_embedding are consumed
54- point_embedding, image_embeddings = self.transformer(
55+ torch._check(point_embeddings.shape[0] != 0)
56+ torch._check(point_embeddings.shape[1] != 0)
57+ torch._check(point_embeddings.shape[2] != 0)
58+ torch._check(point_embeddings.shape[3] != 0)
59+ embeddings_attentions = self.transformer(
60 point_embeddings=point_embeddings,
61 image_embeddings=image_embeddings,
62 image_positional_embeddings=image_positional_embeddings,
63 attention_similarity=attention_similarity,
64 target_embedding=target_embedding,
65+ output_attentions=output_attentions,
66 )
67- iou_token_out = point_embedding[:, :, 0, :]
68- mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
69+ point_embedding, image_embeddings = embeddings_attentions[:2]
70+ iou_token_out = torch.select(point_embedding, dim=2, index=0)
71+ mask_tokens_out = torch.narrow(point_embedding, dim=2, start=1, length=self.num_mask_tokens)
72
73 # Upscale mask embeddings and predict masks using the mask tokens
74 image_embeddings = image_embeddings.transpose(2, 3).reshape(
75@@ -88,4 +107,15 @@
76 mask_slice = slice(0, 1)
77 masks = masks[:, :, mask_slice, :, :]
78 iou_pred = iou_pred[:, :, mask_slice]
79- return masks, iou_pred
80+
81+ outputs = (masks, iou_pred)
82+
83+ if len(embeddings_attentions) == 2:
84+ # transformers==4.54
85+ return outputs
86+
87+ if output_attentions and len(embeddings_attentions) > 2:
88+ outputs = outputs + (embeddings_attentions[2],) # noqa: RUF005
89+ else:
90+ outputs = outputs + (None,) # noqa: RUF005
91+ return outputs
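The same torch.cond pattern appears here to remove the data-dependent test on sparse_prompt_embeddings: the predicate torch.any(sparse_prompt_embeddings != 0) is a boolean tensor, so both alternatives stay in the exported graph. A minimal sketch with illustrative shapes (the function name is hypothetical):

import torch

def concat_if_nonzero(output_tokens, sparse_prompt_embeddings):
    # Concatenate the prompt embeddings only when they are not all zeros,
    # expressed with torch.cond instead of a Python if on a tensor value.
    return torch.cond(
        torch.any(sparse_prompt_embeddings != 0),
        lambda o, s: torch.cat((o, s), dim=2),
        lambda o, s: o.clone(),
        [output_tokens, sparse_prompt_embeddings],
    )

tokens = concat_if_nonzero(torch.ones(1, 2, 5, 8), torch.zeros(1, 2, 3, 8))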
auto/patch_transformers: SmolLM3RotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
The diff for this entry is identical to the common_RotaryEmbedding.forward and patched_dynamic_rope_update diffs shown above.
auto/patch_transformers: VisionAttention.forward -> patched_VisionAttention.forward¶
1--- original
2+++ rewritten
3@@ -3,69 +3,55 @@
4 hidden_states: torch.Tensor,
5 cu_seqlens: torch.Tensor,
6 rotary_pos_emb: Optional[torch.Tensor] = None,
7- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
8- **kwargs,
9+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
10 ) -> torch.Tensor:
11 seq_length = hidden_states.shape[0]
12- query_states, key_states, value_states = (
13+ q, k, v = (
14 self.qkv(hidden_states)
15 .reshape(seq_length, 3, self.num_heads, -1)
16 .permute(1, 0, 2, 3)
17 .unbind(0)
18 )
19- cos, sin = position_embeddings
20- query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
21+ if position_embeddings is None:
22+ transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
23+ "The attention layers in this model are transitioning from "
24+ " computing the RoPE embeddings internally "
25+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
26+ "to using externally computed "
27+ "`position_embeddings` (Tuple of tensors, containing cos and sin)."
28+ " In v4.54 `rotary_pos_emb` will be "
29+ "removed and `position_embeddings` will be mandatory."
30+ )
31+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
32+ cos = emb.cos()
33+ sin = emb.sin()
34+ else:
35+ cos, sin = position_embeddings
36+ q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
37+ q, k, cos, sin
38+ )
39
40- query_states = query_states.transpose(0, 1).unsqueeze(0)
41- key_states = key_states.transpose(0, 1).unsqueeze(0)
42- value_states = value_states.transpose(0, 1).unsqueeze(0)
43+ attention_mask = torch.full(
44+ [1, seq_length, seq_length],
45+ torch.finfo(q.dtype).min,
46+ device=q.device,
47+ dtype=q.dtype,
48+ )
49+ # for i in range(1, len(cu_seqlens)):
50+ # attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
51+ # cu_seqlens[i - 1] : cu_seqlens[i]] = 0
52+ attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
53
54- attention_interface: Callable = eager_attention_forward
55- if self.config._attn_implementation != "eager":
56- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
57-
58- if self.config._attn_implementation == "flash_attention_2":
59- # Flash Attention 2: Use cu_seqlens for variable length attention
60- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
61- attn_output, _ = attention_interface(
62- self,
63- query_states,
64- key_states,
65- value_states,
66- attention_mask=None,
67- scaling=self.scaling,
68- dropout=0.0 if not self.training else self.attention_dropout,
69- cu_seq_lens_q=cu_seqlens,
70- cu_seq_lens_k=cu_seqlens,
71- max_length_q=max_seqlen,
72- max_length_k=max_seqlen,
73- is_causal=False,
74- **kwargs,
75- )
76- else:
77- # Other implementations: Process each chunk separately
78- lengths = cu_seqlens[1:] - cu_seqlens[:-1]
79- splits = [
80- torch.split(tensor, lengths.tolist(), dim=2)
81- for tensor in (query_states, key_states, value_states)
82- ]
83-
84- attn_outputs = [
85- attention_interface(
86- self,
87- q,
88- k,
89- v,
90- attention_mask=None,
91- scaling=self.scaling,
92- dropout=0.0 if not self.training else self.attention_dropout,
93- is_causal=False,
94- **kwargs,
95- )[0]
96- for q, k, v in zip(*splits)
97- ]
98- attn_output = torch.cat(attn_outputs, dim=1)
99-
100- attn_output = attn_output.reshape(seq_length, -1).contiguous()
101+ q = q.transpose(0, 1)
102+ k = k.transpose(0, 1)
103+ v = v.transpose(0, 1)
104+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
105+ attn_weights = attn_weights + attention_mask
106+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
107+ q.dtype
108+ )
109+ attn_output = torch.matmul(attn_weights, v)
110+ attn_output = attn_output.transpose(0, 1)
111+ attn_output = attn_output.reshape(seq_length, -1)
112 attn_output = self.proj(attn_output)
113 return attn_output
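The commented-out loop over cu_seqlens is replaced by rewrite_loop_for_square_mask, whose implementation is not reproduced in this diff. Assuming cu_seqlens starts at 0 and ends at seq_length, the same block-diagonal mask can be built without a Python loop along the following lines; block_diagonal_mask is a hypothetical helper, not the library's function:

import torch

def block_diagonal_mask(cu_seqlens, seq_length, dtype=torch.float32):
    # Tokens may only attend within their own segment: same-segment pairs get 0,
    # every other pair gets the most negative representable value.
    ends = cu_seqlens[1:]
    seg = torch.searchsorted(ends, torch.arange(seq_length), right=True)
    same = seg.unsqueeze(0) == seg.unsqueeze(1)
    mask = torch.full((1, seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
    return mask.masked_fill(same.unsqueeze(0), 0.0)

mask = block_diagonal_mask(torch.tensor([0, 3, 5]), 5)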
auto/patch_transformers: eager_attention_forward -> patched_model_bart_eager_attention_forward¶
1--- original
2+++ rewritten
3@@ -1,27 +1,23 @@
4-def eager_attention_forward(
5- module: nn.Module,
6+def patched_model_bart_eager_attention_forward(
7+ module: torch.nn.Module,
8 query: torch.Tensor,
9 key: torch.Tensor,
10 value: torch.Tensor,
11 attention_mask: Optional[torch.Tensor],
12 scaling: Optional[float] = None,
13 dropout: float = 0.0,
14- **kwargs: Unpack[TransformersKwargs],
15+ head_mask: Optional[torch.Tensor] = None,
16+ **kwargs,
17 ):
18- if scaling is None:
19- scaling = query.size(-1) ** -0.5
20-
21- # Take the dot product between "query" and "key" to get the raw attention scores.
22- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24- if attention_mask is not None:
25- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26- attn_weights = attn_weights + attention_mask
27-
28- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31- attn_output = torch.matmul(attn_weights, value)
32- attn_output = attn_output.transpose(1, 2).contiguous()
33-
34- return attn_output, attn_weights
35+ """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
36+ return common_eager_attention_forward(
37+ module,
38+ query,
39+ key,
40+ value,
41+ attention_mask=attention_mask,
42+ scaling=scaling,
43+ dropout=dropout,
44+ head_mask=head_mask,
45+ **kwargs,
46+ )
auto/patch_transformers: eager_attention_forward -> patched_modeling_marian_eager_attention_forward¶
1--- original
2+++ rewritten
3@@ -1,27 +1,23 @@
4-def eager_attention_forward(
5- module: nn.Module,
6+def patched_modeling_marian_eager_attention_forward(
7+ module: torch.nn.Module,
8 query: torch.Tensor,
9 key: torch.Tensor,
10 value: torch.Tensor,
11 attention_mask: Optional[torch.Tensor],
12 scaling: Optional[float] = None,
13 dropout: float = 0.0,
14- **kwargs: Unpack[TransformersKwargs],
15+ head_mask: Optional[torch.Tensor] = None,
16+ **kwargs,
17 ):
18- if scaling is None:
19- scaling = query.size(-1) ** -0.5
20-
21- # Take the dot product between "query" and "key" to get the raw attention scores.
22- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24- if attention_mask is not None:
25- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26- attn_weights = attn_weights + attention_mask
27-
28- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31- attn_output = torch.matmul(attn_weights, value)
32- attn_output = attn_output.transpose(1, 2).contiguous()
33-
34- return attn_output, attn_weights
35+ """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
36+ return common_eager_attention_forward(
37+ module,
38+ query,
39+ key,
40+ value,
41+ attention_mask=attention_mask,
42+ scaling=scaling,
43+ dropout=dropout,
44+ head_mask=head_mask,
45+ **kwargs,
46+ )
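Both the bart and marian patches delegate to a shared common_eager_attention_forward whose body does not appear in these diffs. Based on the removed original code and the added head_mask argument, it presumably amounts to something like the sketch below; the body and the head_mask broadcasting are assumptions, not the library's actual implementation:

from typing import Optional
import torch

def eager_attention_sketch(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    training: bool = False,
):
    # Standard eager attention, as in the removed original code, with an extra
    # per-head mask multiplied into the attention weights (assumption).
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights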