Patches Diff
Patches are not always needed to export an LLM.
Most of the time, only serialization functions are needed to export
an LLM with a cache (DynamicCache, …).
The function register_additional_serialization_functions
is enough in many cases.
import torch
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions
with register_additional_serialization_functions(patch_transformers=True):
    ep = torch.export.export(...)
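As a minimal sketch (not taken from the library's documentation), the snippet below checks that a cache becomes serializable inside the context manager; it assumes an empty DynamicCache can be created and filled with update().
import torch
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions

# assumption: DynamicCache() can be built empty and filled with update()
cache = DynamicCache()
cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)

with register_additional_serialization_functions(patch_transformers=True):
    # with the serialization functions registered, the cache behaves as a
    # pytree node and can appear among the inputs given to torch.export.export
    flat, _spec = torch.utils._pytree.tree_flatten(cache)
    print(f"{len(flat)} tensors found in the cache")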
The function torch_export_patches
helps fix issues encountered with many models.
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(...)
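As a sketch only (a toy model, not an example from the documentation), the context manager simply wraps the call to torch.export.export; export options such as dynamic_shapes are passed as usual.
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Toy(torch.nn.Module):
    # toy computation with a dynamic dimension
    def forward(self, x):
        return x.unsqueeze(1) + x.unsqueeze(0)

model = Toy()
x = torch.randn(4)
with torch_export_patches(patch_torch=True):
    ep = torch.export.export(
        model, (x,), dynamic_shapes=({0: torch.export.Dim("n")},)
    )
print(ep)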
The class PatchDetails
shows how to retrieve the list of patches involved for a specific model.
Those patches belong to the following list, which depends on the transformers and
pytorch versions.
<<<
import torch
import transformers
print(torch.__version__, transformers.__version__)
>>>
2.10.0.dev20251113+cu130 5.0.0.dev0
Those two versions lead to the following list of patches.
<<<
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_export_patches import torch_export_patches
details = PatchDetails()
with torch_export_patches(
    patch_transformers=True,
    patch_torch=True,
    patch_diffusers=True,
    patch_details=details,
):
    pass
for patch in details.patched:
    if patch.function_to_patch == patch.patch:
        continue
    rst = patch.format_diff(format="rst")
    print()
    print()
    print(rst)
    print()
    print()
>>>
sympy: ‘sympy.core.numbers.IntegerConstant.name’ -> _patch_sympy.<locals>.<lambda>
1sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
torch: infer_size -> patched_infer_size
1--- original
2+++ rewritten
3@@ -1,4 +1,5 @@
4-def infer_size(a, b):
5+def patched_infer_size(a, b):
6+ """Patches ``torch._subclasses.fake_impls.infer_size``."""
7 from torch.fx.experimental.symbolic_shapes import guard_or_false
8
9 dimsA = len(a)
10@@ -23,11 +24,21 @@
11 # expression of an or statement as-is, without bool()'ing it; if this
12 # were not the case, we'd need to write this using torch.sym_or() or
13 # something like that).
14- torch._check(
15- guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
16- lambda: f"The size of tensor a ({sizeA}) "
17- f"must match the size of tensor b ({sizeB}) "
18- f"at non-singleton dimension {i})",
19- )
20- expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
21+ try:
22+ b1 = guard_or_false(sizeA == 1)
23+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
24+ b1 = False
25+ try:
26+ b2 = guard_or_false(sizeB == 1)
27+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
28+ b2 = False
29+ try:
30+ b3 = guard_or_false(sizeA == sizeB)
31+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
32+ b3 = False
33+ if b1 or b2 or b3:
34+ expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
35+ else:
36+ # PATCHED: generic case, the dimension is known, no need to assert
37+ expandedSizes[i] = torch.sym_max(sizeA, sizeB)
38 return tuple(expandedSizes)
torch: _broadcast_shapes -> patched__broadcast_shapes
1--- original
2+++ rewritten
3@@ -1,11 +1,11 @@
4-def _broadcast_shapes(*_shapes):
5+def patched__broadcast_shapes(*_shapes):
6+ """Patches ``torch._refs._broadcast_shapes``."""
7+ from functools import reduce
8+ from torch._prims_common import IntLike
9 from torch.fx.experimental.symbolic_shapes import (
10 guard_or_false,
11 is_nested_int,
12- size_hint,
13 )
14-
15- backed_so = torch.fx.experimental._config.backed_size_oblivious
16
17 shapes = tuple(
18 (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
19@@ -18,17 +18,15 @@
20 for shape in shapes:
21 if not isinstance(shape, Sequence):
22 raise RuntimeError(
23- "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
24+ "Input shapes should be of type ints, a tuple of ints, "
25+ "or a list of ints, got ",
26 shape,
27 )
28
29 # Computes common shape
30- common_shape: list[Union[int, torch.SymInt]] = [
31- 1,
32- ] * reduce(max, (len(shape) for shape in shapes))
33- for arg_idx, shape in enumerate(shapes):
34+ common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
35+ for _arg_idx, shape in enumerate(shapes):
36 for idx in range(-1, -1 - len(shape), -1):
37- # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
38 if is_nested_int(shape[idx]):
39 # Broadcasting is allowed for (j0, 1) or (j0, j0);
40 # not (j0, j1), (j0, 5), etc.
41@@ -37,40 +35,17 @@
42 ):
43 continue
44 else:
45- # When backed size oblivious is used, we specialize for broadcasting
46- # if its the only way to compile the example input.
47- # i.e: s0:1, s1:1 ==>
48- # assert s0==s1, no specialization on ==1 or !=1.
49- # The non-broadcast path is picked
50- # s0:1, s1:4 ==>
51- # specialize(s0) to be 1.
52- # s0:4, s1:1 ==>
53- # specialize(s1) to be 1.
54- if backed_so:
55- a = size_hint(shape[idx], allow_none=True)
56- b = size_hint(common_shape[idx], allow_none=True)
57- if a == 1 and b != 1:
58- torch._check(shape[idx] == 1)
59- if b == 1 and a != 1:
60- torch._check(common_shape[idx] == 1)
61 if guard_or_false(shape[idx] == common_shape[idx]):
62 continue
63-
64- if guard_or_false(common_shape[idx] == 1):
65+ # PATCHED: two cases, if == for sure, no broadcast,
66+ # otherwise maybe broadcast with max(dimensions)
67+ if guard_or_false(common_shape[idx] != 1):
68+ pass
69+ elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
70 if shape[idx] < 0:
71 raise ValueError("Attempting to broadcast a dimension with negative length!")
72 common_shape[idx] = shape[idx]
73-
74- if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
75- # broadcast case .
76- continue
77 else:
78- # If broadcasting is undecided we pick non-broadcast path and add runtime assertion.
79- torch._check(
80- common_shape[idx] == shape[idx],
81- lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
82- f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
83- f"should be broadcastable to {common_shape}",
84- )
85+ common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
86
87 return common_shape
torch: _constrain_user_specified_dimhint_range -> patched__constrain_user_specified_dimhint_range
1--- original
2+++ rewritten
3@@ -1,28 +1,31 @@
4-def _constrain_user_specified_dimhint_range(
5+def patched__constrain_user_specified_dimhint_range(
6 symint: torch.SymInt,
7 hint: int,
8- dim: _DimHint,
9+ dim: "_DimHint", # noqa: F821
10 range_constraints,
11 shape_env,
12- keypath: KeyPath,
13+ keypath: "KeyPath", # noqa: F821
14 i: Optional[int] = None,
15 ) -> Optional[str]:
16+ """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
17+ from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
18+
19 trace_vr = (
20 range_constraints[symint.node.expr]
21 if not is_int(symint)
22 else ValueRanges(int(symint), int(symint))
23 )
24-
25 # warn on 0/1 specialization for Dim.AUTO; not an actual error
26- if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
27- pathstr = f"inputs{pytree.keystr(keypath)}"
28- if i is not None:
29- pathstr += f".shape[{i}]"
30- msg = (
31- f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
32- + f"with a sample input with hint = {hint}."
33- )
34- log.warning(msg)
35+ # PATCHED: remove logging
36+ # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
37+ # pathstr = f"inputs{pytree.keystr(keypath)}"
38+ # if i is not None:
39+ # pathstr += f".shape[{i}]"
40+ # msg = (
41+ # f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
42+ # f"with a sample input with hint = {hint}."
43+ # )
44+ # log.warning(msg)
45
46 try:
47 user_vr = ValueRanges(
48@@ -38,32 +41,40 @@
49
50 # check for Dim.DYNAMIC specializations; special case error message on 0/1
51 if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
52- path = f"inputs{pytree.keystr(keypath)}"
53+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
54 if i is not None:
55 path += f".shape[{i}]"
56 if (
57 trace_vr.is_singleton()
58 and hint in (0, 1)
59- and not torch.fx.experimental._config.backed_size_oblivious
60+ # PATCHED: line removed
61+ # and not torch.fx.experimental._config.backed_size_oblivious
62 ):
63- msg = (
64- f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
65- f"but export 0/1 specialized due to hint of {hint} for dimension {path}."
66- )
67+ return None
68+ # PATCHED: line removed
69+ # msg = (
70+ # f"- Received user-specified dim hint "
71+ # f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
72+ # f"but export 0/1 specialized due to hint of "
73+ # f"{hint} for dimension {path}."
74+ # )
75 else:
76 msg = (
77- f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
78- f"but tracing inferred a static shape of {out_vr.lower} for dimension {path}."
79+ f"- Received user-specified dim hint "
80+ f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
81+ f"but tracing inferred a static shape of "
82+ f"{out_vr.lower} for dimension {path}."
83 )
84 return msg
85
86 except torch.utils._sympy.value_ranges.ValueRangeError:
87- path = f"inputs{pytree.keystr(keypath)}"
88+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
89 if i is not None:
90 path += f".shape[{i}]"
91 msg = (
92 f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
93- f"conflicting with the inferred min/max range of [{trace_vr.lower}, {trace_vr.upper}], "
94+ f"conflicting with the inferred min/max range of "
95+ f"[{trace_vr.lower}, {trace_vr.upper}], "
96 f"for {path}."
97 )
98 return msg
torch: _broadcast_in_dim_meta -> patched__broadcast_in_dim_meta
1--- original
2+++ rewritten
3@@ -1,6 +1,9 @@
4-def _broadcast_in_dim_meta(
5- a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
6+def patched__broadcast_in_dim_meta(
7+ a: torch._prims_common.TensorLikeType,
8+ shape: torch._prims_common.ShapeType,
9+ broadcast_dimensions: Sequence[int],
10 ):
11+ """Patches ``torch._prims._broadcast_in_dim_meta``."""
12 from torch.fx.experimental.symbolic_shapes import (
13 guard_or_false,
14 guard_or_true,
15@@ -8,7 +11,7 @@
16 )
17
18 # Type checks
19- assert isinstance(a, TensorLike)
20+ assert isinstance(a, torch._prims_common.TensorLike)
21 assert isinstance(shape, Sequence)
22 assert isinstance(broadcast_dimensions, Sequence)
23
24@@ -22,7 +25,7 @@
25 # (no relative reordering of dims) of integers and
26 # each dimension must be within the new shape
27 def _greater_than_reduce(acc, x):
28- assert isinstance(x, Dim)
29+ assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
30 assert x > acc
31 assert x < len(shape)
32
33@@ -34,7 +37,9 @@
34 for idx, new_idx in enumerate(broadcast_dimensions):
35 torch._check(
36 sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
37- lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
38+ lambda idx=idx, new_idx=new_idx: (
39+ f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
40+ ),
41 )
42
43 new_strides = []
44@@ -48,10 +53,26 @@
45 new_strides.append(a.stride()[original_idx])
46 else:
47 new_strides.append(0)
48+ # PATCHED: disabled this check
49+ elif guard_or_false(a.shape[original_idx] != 1):
50+ new_strides.append(a.stride()[original_idx])
51 else:
52+ # This checks generates the following issue:
53+ # non-broadcasting semantics require s3 == Max(s10, s3), False,
54+ # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
55+ # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
56+ # original_idx=1
57 torch._check(
58 a.shape[original_idx] == shape[idx],
59- lambda: f"non-broadcasting semantics require {a.shape[original_idx]} == {shape[idx]}",
60+ lambda idx=idx, original_idx=original_idx: (
61+ f"non-broadcasting semantics require "
62+ f"{a.shape[original_idx]} == {shape[idx]}, "
63+ f"{guard_or_false(a.shape[idx] != 1)}, "
64+ f"guard_or_false(a.shape[idx]==1)="
65+ f"{guard_or_false(a.shape[idx] == 1)}, "
66+ f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
67+ f"shape={shape}, original_idx={original_idx}"
68+ ),
69 )
70 new_strides.append(a.stride()[original_idx])
71 original_idx = original_idx + 1
torch: _maybe_broadcast -> patched__maybe_broadcast
1--- original
2+++ rewritten
3@@ -1,6 +1,9 @@
4-def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
5+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
6+ """Patches ``torch._refs._maybe_broadcast``."""
7+ from torch._prims_common import ShapeType, TensorLike, Number
8+
9 # Computes common shape
10- common_shape = _broadcast_shapes(
11+ common_shape = patched__broadcast_shapes(
12 *(t.shape if isinstance(t, TensorLike) else None for t in args)
13 )
14
15@@ -29,10 +32,15 @@
16 return True
17
18 # u0==u1 assume the same, no broadcasting!
19- torch._check(
20- x == y,
21- lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
22- )
23+ # PATCHED: avoid errors
24+ return True # guard_or_true(x != y)
25+ # torch._check(
26+ # x == y,
27+ # lambda x=x, y=y: (
28+ # f"sizes assumed to be the same due to unbacked "
29+ # f"broadcasting semantics x={x!r}, y={y!r}"
30+ # ),
31+ # )
32
33 return False
34
35@@ -42,7 +50,7 @@
36 elif isinstance(x, Number):
37 return x
38 elif isinstance(x, TensorLike):
39- if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
40+ if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
41 return x
42
43 if should_expand(x.shape, common_shape):
44@@ -50,6 +58,6 @@
45
46 return x
47 else:
48- raise RuntimeError("Unexpected type when broadcasting: " + str(type(x)) + "!")
49+ raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
50
51 return tuple(__maybe_broadcast(x, common_shape) for x in args)
torch: ShapeEnv._evaluate_expr -> patched_ShapeEnv._evaluate_expr
1--- original
2+++ rewritten
3@@ -1,14 +1,24 @@
4 def _evaluate_expr(
5 self,
6- orig_expr: sympy.Basic,
7+ orig_expr: "sympy.Basic", # noqa: F821
8 hint: Optional[Union[bool, int, float]] = None,
9 fx_node: Optional[torch.fx.Node] = None,
10 size_oblivious: bool = False,
11 fallback_value: Optional[bool] = None,
12 *,
13 forcing_spec: bool = False,
14-) -> sympy.Basic:
15+) -> "sympy.Basic": # noqa: F821
16 # TODO: split conjunctions and evaluate them separately
17+ import sympy
18+ from torch.fx.experimental import _config as config
19+ from torch.fx.experimental.symbolic_shapes import (
20+ SympyBoolean,
21+ log,
22+ SymT,
23+ symbol_is_type,
24+ )
25+ from torch._guards import ShapeGuard
26+
27 if isinstance(
28 orig_expr,
29 (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
30@@ -118,7 +128,8 @@
31 self._log_suppressed_dde(orig_expr, fallback_value)
32 return fallback_value
33
34- # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type.
35+ # oblivious_var_to_val will be defined iff we have sizes
36+ # with DimDynamic.OBLIVIOUS_SIZE type.
37 # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
38 if (
39 self.oblivious_var_to_val
40@@ -136,15 +147,17 @@
41 log.info(
42 "oblivious_size %s -> %s (passed counterfactual)",
43 orig_expr,
44+ # pyrefly: ignore # unbound-name
45 correct_hint,
46 )
47-
48+ # pyrefly: ignore # unbound-name
49 concrete_val = correct_hint
50 # NB: do NOT transmute into runtime assert
51 ok = True
52
53 # unbacked_var_to_val is not None iff propagate_real_tensors is on.
54- # if propagate_real_tensors is on, we check the example values to generate (unsound_result)
55+ # if propagate_real_tensors is on, we check the example values
56+ # to generate (unsound_result)
57 # and if they pass we add a runtime assertions and continue.
58 if (
59 not ok
60@@ -155,25 +168,29 @@
61 )
62 ).free_symbols
63 ):
64+ # pyrefly: ignore # unbound-name
65 self._log_real_tensor_propagation(orig_expr, unsound_result)
66 transmute_into_runtime_assert = True
67-
68+ # pyrefly: ignore # unbound-name
69 concrete_val = unsound_result
70 ok = True
71
72- # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion
73+ # Check if this is coming from a python assert statement,
74+ # if so, convert it to a runtime assertion
75 # instead of failing.
76 if not ok and self.trace_asserts and self._is_python_assert():
77 concrete_val = sympy.true
78 transmute_into_runtime_assert = True
79 ok = True
80
81- if not ok:
82- raise self._make_data_dependent_error(
83- expr.xreplace(self.var_to_val),
84- expr,
85- expr_sym_node_id=self._expr_sym_node_id,
86- )
87+ # PATCHED: ok -> True
88+ ok = True
89+ # if not ok:
90+ # raise self._make_data_dependent_error(
91+ # expr.xreplace(self.var_to_val),
92+ # expr,
93+ # expr_sym_node_id=self._expr_sym_node_id,
94+ # )
95 else:
96 expr = new_expr
patch_transformers: dynamic_rope_update -> patched_dynamic_rope_update
1--- original
2+++ rewritten
3@@ -1,99 +1,193 @@
4-def dynamic_rope_update(rope_forward):
5- """
6- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
7- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
8+def patched_dynamic_rope_update(rope_forward):
9+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
10
11- Args:
12- rope_forward (Callable):
13- The forward pass of the RoPE implementation.
14+ ``rope_type`` is determined in the constructor of class
15+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
16
17- Returns:
18- The decorated forward pass.
19+ .. code-block:: python
20+
21+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
22+ self.rope_type = config.rope_scaling.get(
23+ "rope_type", config.rope_scaling.get("type"))
24+ else:
25+ self.rope_type = "default"
26+
27+ The original code of the patched function:
28+
29+ .. code-block:: python
30+
31+ def dynamic_rope_update(rope_forward):
32+ def longrope_frequency_update(self, position_ids, device):
33+ seq_len = torch.max(position_ids) + 1
34+ if hasattr(self.config, "original_max_position_embeddings"):
35+ original_max_position_embeddings =
36+ self.config.original_max_position_embeddings
37+ else:
38+ original_max_position_embeddings =
39+ self.config.max_position_embeddings
40+ if seq_len > original_max_position_embeddings:
41+ if not hasattr(self, "long_inv_freq"):
42+ self.long_inv_freq, _ = self.rope_init_fn(
43+ self.config, device, seq_len=original_max_position_embeddings + 1
44+ )
45+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
46+ else:
47+ self.original_inv_freq = self.original_inv_freq.to(device)
48+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
49+
50+ def dynamic_frequency_update(self, position_ids, device):
51+ seq_len = torch.max(position_ids) + 1
52+ if seq_len > self.max_seq_len_cached: # growth
53+ inv_freq, self.attention_scaling = self.rope_init_fn(
54+ self.config, device, seq_len=seq_len)
55+ self.register_buffer("inv_freq", inv_freq, persistent=False)
56+ self.max_seq_len_cached = seq_len
57+
58+ if seq_len < self.original_max_seq_len and
59+ self.max_seq_len_cached > self.original_max_seq_len:
60+ self.original_inv_freq = self.original_inv_freq.to(device)
61+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
62+ self.max_seq_len_cached = self.original_max_seq_len
63+
64+ @wraps(rope_forward)
65+ def wrapper(self, x, position_ids):
66+ if "dynamic" in self.rope_type:
67+ dynamic_frequency_update(self, position_ids, device=x.device)
68+ elif self.rope_type == "longrope":
69+ longrope_frequency_update(self, position_ids, device=x.device)
70+ return rope_forward(self, x, position_ids)
71+
72+ return wrapper
73+
74 """
75
76 def longrope_frequency_update(self, position_ids, device, layer_type=None):
77- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
78+ # It is no use to patch the function after the model is created
79+ # as rope_init_fn is an attribute set to one function when the model
80+ # is created and when no patch is applied yet.
81+ # So we select the patched version here.
82+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
83 seq_len = torch.max(position_ids) + 1
84- original_max_position_embeddings = getattr(
85- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
86- )
87+ if hasattr(self.config, "original_max_position_embeddings"):
88+ original_max_position_embeddings = self.config.original_max_position_embeddings
89+ else:
90+ original_max_position_embeddings = self.config.max_position_embeddings
91+
92 if layer_type is None:
93- rope_type = self.rope_type
94+ # rope_type = self.rope_type
95 original_inv_freq = self.original_inv_freq
96 prefix = ""
97 else:
98- rope_type = self.rope_type[layer_type]
99+ # rope_type = self.rope_type[layer_type]
100 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
101 prefix = f"{layer_type}_"
102
103- if seq_len > original_max_position_embeddings:
104- if not hasattr(self, f"{layer_type}_long_inv_freq"):
105- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
106- long_inv_freq, _ = rope_init_fn(
107- self.config,
108- device,
109- seq_len=original_max_position_embeddings + 1,
110- layer_type=layer_type,
111- )
112- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
113- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
114- else:
115- # This .to() is needed if the model has been moved to a device after being initialized (because
116- # the buffer is automatically moved, but not the original copy)
117- original_inv_freq = original_inv_freq.to(device)
118- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
119- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
120+ # At export time, seq_len is unknown.
121+ long_inv_freq, _ = rope_init_fn(
122+ self.config, device, seq_len=original_max_position_embeddings + 1
123+ )
124+ original_inv_freq = self.original_inv_freq.to(device)
125+
126+ # PATCHED: uses torch.cond instead of a test
127+ cond = (seq_len > original_max_position_embeddings).item()
128+ inv_freq = torch.cond(
129+ cond,
130+ (lambda x, y: x.clone()),
131+ (lambda x, y: y.clone()),
132+ [long_inv_freq, original_inv_freq],
133+ )
134+ setattr(self, f"{prefix}inv_freq", inv_freq)
135+ # if seq_len > original_max_position_embeddings:
136+ # self.inv_freq = self.long_inv_freq
137+ # else:
138+ # self.inv_freq = self.original_inv_freq
139
140 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
141- """
142- dynamic RoPE layers should recompute `inv_freq` in the following situations:
143- 1 - growing beyond the cached sequence length (allow scaling)
144- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
145- """
146+ # constructor:
147+ # - self.max_seq_len_cached = config.max_position_embeddings
148+ # - self.original_max_seq_len = config.max_position_embeddings
149+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
150+
151+ # It is no use to patch the function after the model is created
152+ # as rope_init_fn is an attribute set to one function when the model
153+ # is created and when no patch is applied yet.
154+ # So we select the patched version here.
155+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
156+
157+ # This behaviour is difficult to translate.
158+ # The sequence always grows.
159+ # The test should always True.
160+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
161+ #
162+ # if seq_len > self.max_seq_len_cached: # growth
163+ # inv_freq, self.attention_scaling = self.rope_init_fn(
164+ # self.config, device, seq_len=seq_len
165+ # )
166+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
167+ # self.max_seq_len_cached = seq_len
168+ #
169+ # So we should not need what follows.
170+ #
171+ # cond = (seq_len > self.max_seq_len_cached).item()
172+ # self.attention_scaling = torch.cond(
173+ # cond,
174+ # (lambda x, y: x.clone()),
175+ # (lambda x, y: y.clone()),
176+ # [attention_scaling, self.attention_scaling],
177+ # )
178+
179 seq_len = torch.max(position_ids) + 1
180+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
181+
182 if layer_type is None:
183- rope_type = self.rope_type
184- max_seq_len_cached = self.max_seq_len_cached
185+ # rope_type = self.rope_type
186+ # max_seq_len_cached = self.max_seq_len_cached
187 original_inv_freq = self.original_inv_freq
188 prefix = ""
189 else:
190- rope_type = self.rope_type[layer_type]
191- max_seq_len_cached = getattr(
192- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
193- )
194+ # rope_type = self.rope_type[layer_type]
195+ # max_seq_len_cached = getattr(
196+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
197+ # )
198 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
199 prefix = f"{layer_type}_"
200
201- if seq_len > max_seq_len_cached: # growth
202- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
203- inv_freq, self.attention_scaling = rope_init_fn(
204- self.config,
205- device,
206- seq_len=seq_len,
207- layer_type=layer_type,
208- )
209- # TODO joao: may break with compilation
210- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
211- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
212+ # Second test to translate.
213+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
214+ # But in that case the following condition is a way to restore the original cache.
215
216- if (
217- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
218- ): # reset
219- # This .to() is needed if the model has been moved to a device after being initialized (because
220- # the buffer is automatically moved, but not the original copy)
221- original_inv_freq = original_inv_freq.to(device)
222- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
223- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
224- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
225+ # if (
226+ # seq_len < self.original_max_seq_len
227+ # and self.max_seq_len_cached > self.original_max_seq_len
228+ # ):
229+ # self.original_inv_freq = self.original_inv_freq.to(device)
230+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
231+ # self.max_seq_len_cached = self.original_max_seq_len
232+
233+ original_inv_freq = self.original_inv_freq.to(device)
234+ cond = (seq_len >= self.original_max_seq_len).item()
235+ # PATCHED: uses torch.cond instead of a test
236+ inv_freq = torch.cond(
237+ cond,
238+ (lambda x, y: x.clone()),
239+ (lambda x, y: y.clone()),
240+ [long_inv_freq, original_inv_freq],
241+ )
242+ setattr(self, f"{prefix}inv_freq", inv_freq)
243
244 @wraps(rope_forward)
245 def wrapper(self, x, position_ids, layer_type=None):
246- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
247- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
248- if "dynamic" in rope_type:
249- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
250- elif rope_type == "longrope":
251- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
252- return rope_forward(self, x, position_ids, **kwargs)
253+ if layer_type is None:
254+ if "dynamic" in self.rope_type:
255+ dynamic_frequency_update(self, position_ids, device=x.device)
256+ elif self.rope_type == "longrope":
257+ longrope_frequency_update(self, position_ids, device=x.device)
258+ return rope_forward(self, x, position_ids)
259+
260+ if "dynamic" in self.rope_type:
261+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
262+ elif self.rope_type == "longrope":
263+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
264+ return rope_forward(self, x, position_ids, layer_type=layer_type)
265
266 return wrapper
transformers: eager_mask -> patched_eager_mask
1--- original
2+++ rewritten
3@@ -1,4 +1,4 @@
4-def eager_mask(
5+def patched_eager_mask(
6 batch_size: int,
7 cache_position: torch.Tensor,
8 kv_length: int,
9@@ -6,38 +6,14 @@
10 mask_function: Callable = causal_mask_function,
11 attention_mask: Optional[torch.Tensor] = None,
12 dtype: torch.dtype = torch.float32,
13- use_vmap: bool = False,
14 **kwargs,
15 ) -> torch.Tensor:
16- """
17- Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
18- the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
19- it should not.
20-
21- Args:
22- batch_size (`int`):
23- The batch size of the input sequence.
24- cache_position (`torch.Tensor`):
25- A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
26- kv_length (`int`):
27- The size that the key and value states will have during the attention computation.
28- kv_offset (`int`, optional):
29- An optional offset to indicate at which first position the key and values states will refer to.
30- mask_function (`Callable`):
31- The mask factory function describing the mask pattern.
32- attention_mask (`torch.Tensor`, optional):
33- The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
34- dtype (`torch.dtype`, optional):
35- The dtype to use for the mask. By default, `torch.float32`.
36- use_vmap (`bool`, optional):
37- Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
38- index-based (for the cost of speed performance). By default `False`.
39- """
40+ """manual patch for function ``transformers.masking_utils.eager_mask``."""
41 # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
42 _ = kwargs.pop("allow_is_causal_skip", None)
43 _ = kwargs.pop("allow_is_bidirectional_skip", None)
44- _ = kwargs.pop("allow_torch_fix", None)
45- mask = sdpa_mask(
46+ # PATCHED: this line called the patched version of sdpa_mask
47+ mask = patched_sdpa_mask_recent_torch(
48 batch_size=batch_size,
49 cache_position=cache_position,
50 kv_length=kv_length,
51@@ -47,10 +23,13 @@
52 allow_is_causal_skip=False,
53 allow_is_bidirectional_skip=False,
54 allow_torch_fix=False,
55- use_vmap=use_vmap,
56 **kwargs,
57 )
58 min_dtype = torch.finfo(dtype).min
59- # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
60- mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
61+ # PATCHED: the following line
62+ # we need 0s where the tokens should be taken into account,
63+ # and -inf otherwise (mask is already of boolean type)
64+ # mask =
65+ # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
66+ mask = (~mask).to(dtype) * min_dtype
67 return mask
transformers: sdpa_attention_forward -> patched_sdpa_attention_forward
1--- original
2+++ rewritten
3@@ -1,4 +1,4 @@
4-def sdpa_attention_forward(
5+def patched_sdpa_attention_forward(
6 module: torch.nn.Module,
7 query: torch.Tensor,
8 key: torch.Tensor,
9@@ -9,57 +9,133 @@
10 is_causal: Optional[bool] = None,
11 **kwargs,
12 ) -> tuple[torch.Tensor, None]:
13- if kwargs.get("output_attentions", False):
14- logger.warning_once(
15- "`sdpa` attention does not support `output_attentions=True`."
16- " Please set your attention to `eager` if you want any of these features."
17- )
18+ """
19+ manual patch for function
20+ ``transformers.integrations.sdpa_attention.sdpa_attention_forward``
21+ """
22+ assert not kwargs.get("output_attentions", False), (
23+ "`sdpa` attention does not support `output_attentions=True`."
24+ " Please set your attention to `eager` if you want any of these features."
25+ )
26+ torch._check(
27+ query.shape[0] == key.shape[0] or query.shape[0] == 1,
28+ lambda: (
29+ f"broadcast issue query (1): {query.shape}, key: {key.shape}, "
30+ f"value: {value.shape}"
31+ ),
32+ )
33+ torch._check(
34+ key.shape[0] == value.shape[0] or key.shape[0] == 1,
35+ lambda: (
36+ f"broadcast issue query (2): {query.shape}, key: {key.shape}, "
37+ f"value: {value.shape}"
38+ ),
39+ )
40+
41 sdpa_kwargs = {}
42 if hasattr(module, "num_key_value_groups"):
43- if not use_gqa_in_sdpa(attention_mask, key):
44- key = repeat_kv(key, module.num_key_value_groups)
45- value = repeat_kv(value, module.num_key_value_groups)
46+ if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
47+ key = transformers.integrations.sdpa_attention.repeat_kv(
48+ key, module.num_key_value_groups
49+ )
50+ value = transformers.integrations.sdpa_attention.repeat_kv(
51+ value, module.num_key_value_groups
52+ )
53 else:
54 sdpa_kwargs = {"enable_gqa": True}
55
56- # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
57- is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
58+ if attention_mask is not None and attention_mask.ndim == 4:
59+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
60
61- # SDPA's Flash Attention (and cuDNN) kernels rely on the `is_causal` flag. However, there are certain conditions:
62- # - Not in decoding phase (otherwise we want full attention on the single query token)
63- # - Attention mask is not to be provided (even if it is a causal pattern)
64- # - Internally, we marked this as compatible with causal, i.e. it is a decoder attention type
65- #
66- # Quirks on the conditionals:
67- # - We avoid inline passing this to the SDPA function directly to support both torch.compile's dynamic shapes and
68- # full graph options. Otherwise, dynamic shapes are prevented from compiling.
69- # - It is important to check first for the shape, otherwise compile will fail with
70- # `argument 'is_causal' must be bool, not SymBool`.
71- is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
72+ torch._check(
73+ attention_mask is None or attention_mask.shape[3] == key.shape[2],
74+ lambda: "Attention mask shape incompatible with key shape.",
75+ )
76
77- # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
78- # We convert it to a bool for the SDPA kernel that only accepts bools.
79- if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
80- is_causal = is_causal.item()
81+ if patch_sdpa_is_causal:
82+ # transformers>=4.55
83+ is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
84
85- # When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator,
86- # and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
87- # This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
88- if _is_torch_npu_available:
89- if attention_mask is not None and attention_mask.dtype != torch.bool:
90- # Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
91- attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
92+ # PATCHED: remove the test query.shape[2] > 1
93+ # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
94+ # and we split the test to keep the minimum in torch.cond
95+ is_causal = attention_mask is None and is_causal
96
97- attn_output = torch.nn.functional.scaled_dot_product_attention(
98- query,
99- key,
100- value,
101- attn_mask=attention_mask,
102- dropout_p=dropout,
103- scale=scaling,
104- is_causal=is_causal,
105- **sdpa_kwargs,
106+ if not is_causal:
107+ torch._check(query.shape[0] > 0)
108+ torch._check(query.shape[1] > 0)
109+ torch._check(query.shape[2] > 0)
110+ torch._check(query.shape[3] > 0)
111+ torch._check(key.shape[0] > 0)
112+ torch._check(key.shape[1] > 0)
113+ torch._check(key.shape[2] > 0)
114+ torch._check(key.shape[3] > 0)
115+ torch._check(value.shape[0] > 0)
116+ torch._check(value.shape[1] > 0)
117+ torch._check(value.shape[2] > 0)
118+ torch._check(value.shape[3] > 0)
119+ return (
120+ torch.nn.functional.scaled_dot_product_attention(
121+ query,
122+ key,
123+ value,
124+ attn_mask=attention_mask,
125+ dropout_p=dropout,
126+ scale=scaling,
127+ is_causal=is_causal,
128+ **sdpa_kwargs,
129+ )
130+ .transpose(1, 2)
131+ .contiguous(),
132+ None,
133+ )
134+ else:
135+ # transformers<4.55
136+ if is_causal is None and attention_mask is not None:
137+ is_causal = False
138+ if is_causal is not None:
139+ return (
140+ torch.nn.functional.scaled_dot_product_attention(
141+ query,
142+ key,
143+ value,
144+ attn_mask=attention_mask,
145+ dropout_p=dropout,
146+ scale=scaling,
147+ is_causal=is_causal,
148+ **sdpa_kwargs,
149+ )
150+ .transpose(1, 2)
151+ .contiguous(),
152+ None,
153+ )
154+
155+ # To avoid the following errors:
156+ # is_causal=query.shape[2] > 1
157+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
158+ # is_causal=torch.tensor(query.shape[2] > 1)
159+ # TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor
160+ attn_output = torch.cond(
161+ query.shape[2] > 1, # distinction between prefill and decoding steps
162+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
163+ query,
164+ key,
165+ value,
166+ dropout_p=dropout,
167+ scale=scaling,
168+ is_causal=True,
169+ **sdpa_kwargs,
170+ ).contiguous(),
171+ lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
172+ query,
173+ key,
174+ value,
175+ dropout_p=dropout,
176+ scale=scaling,
177+ is_causal=False,
178+ **sdpa_kwargs,
179+ ).contiguous(),
180+ [query, key, value],
181 )
182 attn_output = attn_output.transpose(1, 2).contiguous()
183-
184 return attn_output, None
auto/patch_transformers: DynamicLayer.lazy_initialization -> patched_DynamicLayer.lazy_initialization
1--- original
2+++ rewritten
3@@ -1,5 +1,9 @@
4 def lazy_initialization(self, key_states: torch.Tensor):
5 self.dtype, self.device = key_states.dtype, key_states.device
6- self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
7- self.values = torch.tensor([], dtype=self.dtype, device=self.device)
8- self.is_initialized = True
9+ new_shape = list(key_states.shape)
10+ new_shape[-2] = 0
11+ # PATCHED: used a tensor with an empty shape and not en empty list to initialize
12+ self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
13+ self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
14+ if patch_is_initialized:
15+ self.is_initialized = True
auto/patch_transformers: Gemma2RotaryEmbedding.forward -> common_RotaryEmbedding.forward
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
auto/patch_transformers: Gemma3Model.get_placeholder_mask -> patched_Gemma3Model.get_placeholder_mask
1--- original
2+++ rewritten
3@@ -4,14 +4,12 @@
4 inputs_embeds: torch.FloatTensor,
5 image_features: torch.FloatTensor,
6 ):
7- """
8- Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
9- equal to the length of multimodal features. If the lengths are different, an error is raised.
10- """
11 if input_ids is None:
12 special_image_mask = inputs_embeds == self.get_input_embeddings()(
13 torch.tensor(
14- self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device
15+ self.config.image_token_id,
16+ dtype=torch.long,
17+ device=inputs_embeds.device,
18 )
19 )
20 special_image_mask = special_image_mask.all(-1)
21@@ -23,8 +21,14 @@
22 special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
23 )
24 n_image_features = image_features.shape[0] * image_features.shape[1]
25- if inputs_embeds[special_image_mask].numel() != image_features.numel():
26- raise ValueError(
27- f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
28- )
29+ # PATCHED: torch._check
30+ # if inputs_embeds[special_image_mask].numel() != image_features.numel():
31+ # raise ValueError( ... )
32+ torch._check(
33+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
34+ lambda: (
35+ f"Image features and image tokens do not match: tokens: "
36+ f"{n_image_tokens}, features {n_image_features}"
37+ ),
38+ )
39 return special_image_mask
auto/patch_transformers: Gemma3RotaryEmbedding.forward -> common_RotaryEmbedding.forward
1--- original
2+++ rewritten
3@@ -1,8 +1,13 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6+@patched_dynamic_rope_update
7 def forward(self, x, position_ids, layer_type=None):
8- inv_freq = getattr(self, f"{layer_type}_inv_freq")
9- attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
10+ if layer_type is not None:
11+ # transformers>=5.0
12+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
13+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
14+ else:
15+ # transformers<5.0
16+ inv_freq = self.inv_freq
17+ attention_scaling = self.attention_scaling
18
19 inv_freq_expanded = (
20 inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21
22--- original
23+++ rewritten
24@@ -1,99 +1,193 @@
25-def dynamic_rope_update(rope_forward):
26- """
27- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
28- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
29+def patched_dynamic_rope_update(rope_forward):
30+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
31
32- Args:
33- rope_forward (Callable):
34- The forward pass of the RoPE implementation.
35+ ``rope_type`` is determined in the constructor of class
36+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
37
38- Returns:
39- The decorated forward pass.
40+ .. code-block:: python
41+
42+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
43+ self.rope_type = config.rope_scaling.get(
44+ "rope_type", config.rope_scaling.get("type"))
45+ else:
46+ self.rope_type = "default"
47+
48+ The original code of the patched function:
49+
50+ .. code-block:: python
51+
52+ def dynamic_rope_update(rope_forward):
53+ def longrope_frequency_update(self, position_ids, device):
54+ seq_len = torch.max(position_ids) + 1
55+ if hasattr(self.config, "original_max_position_embeddings"):
56+ original_max_position_embeddings =
57+ self.config.original_max_position_embeddings
58+ else:
59+ original_max_position_embeddings =
60+ self.config.max_position_embeddings
61+ if seq_len > original_max_position_embeddings:
62+ if not hasattr(self, "long_inv_freq"):
63+ self.long_inv_freq, _ = self.rope_init_fn(
64+ self.config, device, seq_len=original_max_position_embeddings + 1
65+ )
66+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
67+ else:
68+ self.original_inv_freq = self.original_inv_freq.to(device)
69+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
70+
71+ def dynamic_frequency_update(self, position_ids, device):
72+ seq_len = torch.max(position_ids) + 1
73+ if seq_len > self.max_seq_len_cached: # growth
74+ inv_freq, self.attention_scaling = self.rope_init_fn(
75+ self.config, device, seq_len=seq_len)
76+ self.register_buffer("inv_freq", inv_freq, persistent=False)
77+ self.max_seq_len_cached = seq_len
78+
79+ if seq_len < self.original_max_seq_len and
80+ self.max_seq_len_cached > self.original_max_seq_len:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+ self.max_seq_len_cached = self.original_max_seq_len
84+
85+ @wraps(rope_forward)
86+ def wrapper(self, x, position_ids):
87+ if "dynamic" in self.rope_type:
88+ dynamic_frequency_update(self, position_ids, device=x.device)
89+ elif self.rope_type == "longrope":
90+ longrope_frequency_update(self, position_ids, device=x.device)
91+ return rope_forward(self, x, position_ids)
92+
93+ return wrapper
94+
95 """
96
97 def longrope_frequency_update(self, position_ids, device, layer_type=None):
98- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
99+ # It is no use to patch the function after the model is created
100+ # as rope_init_fn is an attribute set to one function when the model
101+ # is created and when no patch is applied yet.
102+ # So we select the patched version here.
103+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
104 seq_len = torch.max(position_ids) + 1
105- original_max_position_embeddings = getattr(
106- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
107- )
108+ if hasattr(self.config, "original_max_position_embeddings"):
109+ original_max_position_embeddings = self.config.original_max_position_embeddings
110+ else:
111+ original_max_position_embeddings = self.config.max_position_embeddings
112+
113 if layer_type is None:
114- rope_type = self.rope_type
115+ # rope_type = self.rope_type
116 original_inv_freq = self.original_inv_freq
117 prefix = ""
118 else:
119- rope_type = self.rope_type[layer_type]
120+ # rope_type = self.rope_type[layer_type]
121 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
122 prefix = f"{layer_type}_"
123
124- if seq_len > original_max_position_embeddings:
125- if not hasattr(self, f"{layer_type}_long_inv_freq"):
126- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
127- long_inv_freq, _ = rope_init_fn(
128- self.config,
129- device,
130- seq_len=original_max_position_embeddings + 1,
131- layer_type=layer_type,
132- )
133- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
134- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
135- else:
136- # This .to() is needed if the model has been moved to a device after being initialized (because
137- # the buffer is automatically moved, but not the original copy)
138- original_inv_freq = original_inv_freq.to(device)
139- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
140- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
141+ # At export time, seq_len is unknown.
142+ long_inv_freq, _ = rope_init_fn(
143+ self.config, device, seq_len=original_max_position_embeddings + 1
144+ )
145+ original_inv_freq = self.original_inv_freq.to(device)
146+
147+ # PATCHED: uses torch.cond instead of a test
148+ cond = (seq_len > original_max_position_embeddings).item()
149+ inv_freq = torch.cond(
150+ cond,
151+ (lambda x, y: x.clone()),
152+ (lambda x, y: y.clone()),
153+ [long_inv_freq, original_inv_freq],
154+ )
155+ setattr(self, f"{prefix}inv_freq", inv_freq)
156+ # if seq_len > original_max_position_embeddings:
157+ # self.inv_freq = self.long_inv_freq
158+ # else:
159+ # self.inv_freq = self.original_inv_freq
160
161 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
162- """
163- dynamic RoPE layers should recompute `inv_freq` in the following situations:
164- 1 - growing beyond the cached sequence length (allow scaling)
165- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
166- """
167+ # constructor:
168+ # - self.max_seq_len_cached = config.max_position_embeddings
169+ # - self.original_max_seq_len = config.max_position_embeddings
170+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
171+
172+ # It is no use to patch the function after the model is created
173+ # as rope_init_fn is an attribute set to one function when the model
174+ # is created and when no patch is applied yet.
175+ # So we select the patched version here.
176+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
177+
178+ # This behaviour is difficult to translate.
179+ # The sequence always grows.
180+        # The test should always be True.
181+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
182+ #
183+ # if seq_len > self.max_seq_len_cached: # growth
184+ # inv_freq, self.attention_scaling = self.rope_init_fn(
185+ # self.config, device, seq_len=seq_len
186+ # )
187+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
188+ # self.max_seq_len_cached = seq_len
189+ #
190+ # So we should not need what follows.
191+ #
192+ # cond = (seq_len > self.max_seq_len_cached).item()
193+ # self.attention_scaling = torch.cond(
194+ # cond,
195+ # (lambda x, y: x.clone()),
196+ # (lambda x, y: y.clone()),
197+ # [attention_scaling, self.attention_scaling],
198+ # )
199+
200 seq_len = torch.max(position_ids) + 1
201+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
202+
203 if layer_type is None:
204- rope_type = self.rope_type
205- max_seq_len_cached = self.max_seq_len_cached
206+ # rope_type = self.rope_type
207+ # max_seq_len_cached = self.max_seq_len_cached
208 original_inv_freq = self.original_inv_freq
209 prefix = ""
210 else:
211- rope_type = self.rope_type[layer_type]
212- max_seq_len_cached = getattr(
213- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
214- )
215+ # rope_type = self.rope_type[layer_type]
216+ # max_seq_len_cached = getattr(
217+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
218+ # )
219 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
220 prefix = f"{layer_type}_"
221
222- if seq_len > max_seq_len_cached: # growth
223- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
224- inv_freq, self.attention_scaling = rope_init_fn(
225- self.config,
226- device,
227- seq_len=seq_len,
228- layer_type=layer_type,
229- )
230- # TODO joao: may break with compilation
231- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
232- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
233+ # Second test to translate.
234+        # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
235+ # But in that case the following condition is a way to restore the original cache.
236
237- if (
238- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
239- ): # reset
240- # This .to() is needed if the model has been moved to a device after being initialized (because
241- # the buffer is automatically moved, but not the original copy)
242- original_inv_freq = original_inv_freq.to(device)
243- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
244- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
245- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
246+ # if (
247+ # seq_len < self.original_max_seq_len
248+ # and self.max_seq_len_cached > self.original_max_seq_len
249+ # ):
250+ # self.original_inv_freq = self.original_inv_freq.to(device)
251+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
252+ # self.max_seq_len_cached = self.original_max_seq_len
253+
254+ original_inv_freq = self.original_inv_freq.to(device)
255+ cond = (seq_len >= self.original_max_seq_len).item()
256+ # PATCHED: uses torch.cond instead of a test
257+ inv_freq = torch.cond(
258+ cond,
259+ (lambda x, y: x.clone()),
260+ (lambda x, y: y.clone()),
261+ [long_inv_freq, original_inv_freq],
262+ )
263+ setattr(self, f"{prefix}inv_freq", inv_freq)
264
265 @wraps(rope_forward)
266 def wrapper(self, x, position_ids, layer_type=None):
267- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
268- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
269- if "dynamic" in rope_type:
270- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
271- elif rope_type == "longrope":
272- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
273- return rope_forward(self, x, position_ids, **kwargs)
274+ if layer_type is None:
275+ if "dynamic" in self.rope_type:
276+ dynamic_frequency_update(self, position_ids, device=x.device)
277+ elif self.rope_type == "longrope":
278+ longrope_frequency_update(self, position_ids, device=x.device)
279+ return rope_forward(self, x, position_ids)
280+
281+ if "dynamic" in self.rope_type:
282+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
283+ elif self.rope_type == "longrope":
284+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
285+ return rope_forward(self, x, position_ids, layer_type=layer_type)
286
287 return wrapper
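In the rewritten longrope_frequency_update above, the data-dependent test seq_len > original_max_position_embeddings is replaced with torch.cond so that both outcomes are traced and the selection stays in the exported graph. The sketch below isolates that pattern outside of any RoPE class; pick_frequencies and its arguments are illustrative names only.

import torch

def pick_frequencies(seq_len, long_inv_freq, short_inv_freq, threshold):
    # torch.cond traces both branches; the predicate is a scalar boolean
    # tensor evaluated at runtime. The clones avoid returning aliases of
    # the branch inputs.
    return torch.cond(
        seq_len > threshold,
        lambda a, b: a.clone(),
        lambda a, b: b.clone(),
        [long_inv_freq, short_inv_freq],
    )

long_inv_freq = torch.rand(8)
short_inv_freq = torch.rand(8)
print(pick_frequencies(torch.tensor(4096), long_inv_freq, short_inv_freq, 2048))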
auto/patch_transformers: GemmaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+        # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+        # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation -> patched_GenerationMixin._cache_dependant_input_preparation¶
1--- original
2+++ rewritten
3@@ -3,23 +3,29 @@
4 input_ids: torch.LongTensor,
5 inputs_embeds: Optional[torch.FloatTensor],
6 cache_position: Optional[torch.LongTensor],
7-) -> tuple[torch.FloatTensor, torch.LongTensor]:
8+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
9 """
10 Generic cache-dependent input preparation
11 The code is put in a separate function to allow granular unit testing
12 as it needs a different implementation to be exportable.
13
14- If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
15- - Exception 1: when passing input_embeds, input_ids may be missing entries
16- - Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
17- - Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
18- - Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
19- generate the first token for each sequence. Later use the generated Input ids for continuation.
20+ If we have cache: let's slice `input_ids` through `cache_position`,
21+ to keep only the unprocessed tokens
22+ - Exception 1: when passing input_embeds,
23+ input_ids may be missing entries
24+ - Exception 2: some generation methods do special slicing of input_ids,
25+ so we don't need to do it here
26+ - Exception 3: with synced GPUs cache_position may go out of bounds,
27+ but we only want dummy token in that case.
28+ - Exception 4: If input_embeds are passed then slice it through
29+ `cache_position`, to keep only the unprocessed tokens and
30+ generate the first token for each sequence.
31+ Later use the generated Input ids for continuation.
32
33 The current implementation does not rely on ``self`` and could be
34 a class method. It is left as a standard method to be easily rewritten.
35 """
36- if is_torchdynamo_exporting():
37+ if _is_torchdynamo_exporting():
38 return self._cache_dependant_input_preparation_exporting(
39 input_ids, inputs_embeds, cache_position
40 )
auto/patch_transformers: GenerationMixin._cache_dependant_input_preparation_exporting -> patched_GenerationMixin._cache_dependant_input_preparation_exporting¶
1--- original
2+++ rewritten
3@@ -3,7 +3,7 @@
4 input_ids: torch.LongTensor,
5 inputs_embeds: Optional[torch.FloatTensor],
6 cache_position: Optional[torch.LongTensor],
7-) -> tuple[torch.FloatTensor, torch.LongTensor]:
8+) -> Tuple[torch.FloatTensor, torch.LongTensor]:
9 """
10 This method implements method ``_cache_dependant_input_preparation``
11 with :func:`torch.cond` to make it exportable with :func:`torch.export.export`.
12@@ -21,7 +21,6 @@
13 # else:
14 # if input_ids.shape[1] != cache_position.shape[0]:
15 # input_ids = input_ids[:, cache_position]
16- # We need to clone the outputs to avoid aliasing.
17 def branch_1(inputs_embeds, cache_position):
18 return inputs_embeds[:, -cache_position.shape[0] :].clone()
19
20@@ -49,7 +48,7 @@
21 torch.cond(
22 input_ids.shape[1] != cache_position.shape[0],
23 branch_3,
24- (lambda input_ids, cache_position: input_ids.clone()),
25+ (lambda input_ids, cache_position: input_ids),
26 [input_ids, cache_position],
27 )
28 ),
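Here the predicate passed to torch.cond compares two shapes rather than tensor values, and both branches receive the same operands even when one of them ignores an argument. A reduced sketch of that pattern, with illustrative names:

import torch

def slice_if_needed(input_ids, cache_position):
    # Keep only the positions in cache_position when the lengths differ,
    # otherwise return input_ids unchanged. Both branches take the same
    # operands and return a single tensor.
    def sliced(ids, pos):
        return ids[:, pos].clone()

    def unchanged(ids, pos):
        return ids.clone()

    return torch.cond(
        input_ids.shape[1] != cache_position.shape[0],
        sliced,
        unchanged,
        [input_ids, cache_position],
    )

input_ids = torch.arange(12).reshape(2, 6)
cache_position = torch.tensor([4, 5])
print(slice_if_needed(input_ids, cache_position))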
auto/patch_transformers: IdeficsAttention.forward -> patched_IdeficsAttention.forward¶
1--- original
2+++ rewritten
3@@ -4,10 +4,12 @@
4 key_value_states: Optional[torch.Tensor] = None,
5 attention_mask: Optional[torch.Tensor] = None,
6 position_ids: Optional[torch.LongTensor] = None,
7- past_key_values: Optional[Cache] = None,
8+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
9+ output_attentions: bool = False,
10+ use_cache: bool = False,
11 cache_position: Optional[torch.LongTensor] = None,
12- **kwargs: Unpack[TransformersKwargs],
13-) -> tuple[torch.Tensor, torch.Tensor]:
14+ **kwargs,
15+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
16 # if key_value_states are provided this layer is used as a cross-attention layer
17 is_cross_attention = self.is_cross_attention or key_value_states is not None
18
19@@ -43,20 +45,27 @@
20 )
21
22 kv_seq_len = key_states.shape[-2]
23- if past_key_values is not None:
24+ if past_key_value is not None:
25 kv_seq_len += cache_position[0]
26
27 if not is_cross_attention:
28- cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
29- query_states, key_states = apply_rotary_pos_emb(
30- query_states, key_states, cos, sin, position_ids
31+ rotary_length = torch.maximum(
32+ torch.tensor(kv_seq_len, dtype=torch.int64),
33+ torch.tensor(q_len, dtype=torch.int64),
34+ )
35+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_length)
36+ query_states, key_states = (
37+ transformers.models.idefics.modeling_idefics.apply_rotary_pos_emb(
38+ query_states, key_states, cos, sin, position_ids
39+ )
40 )
41 # [bsz, nh, t, hd]
42
43- if past_key_values is not None:
44- # sin and cos are specific to RoPE models; cache_position needed for the static cache
45+ if past_key_value is not None:
46+ # sin and cos are specific to RoPE models;
47+ # cache_position needed for the static cache
48 cache_kwargs = {"cache_position": cache_position}
49- key_states, value_states = past_key_values.update(
50+ key_states, value_states = past_key_value.update(
51 key_states, value_states, self.layer_idx, cache_kwargs
52 )
53
54@@ -64,10 +73,22 @@
55 query_states = self.q_layer_norm(query_states)
56 key_states = self.k_layer_norm(key_states)
57
58- attention_interface: Callable = eager_attention_forward
59+ attention_interface: Callable = (
60+ transformers.models.idefics.modeling_idefics.eager_attention_forward
61+ )
62
63 if self.config._attn_implementation != "eager":
64- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
65+ if self.config._attn_implementation == "sdpa" and output_attentions:
66+ transformers.models.idefics.modeling_idefics.logger.warning_once(
67+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
68+ "`output_attentions=True`. Falling back to "
69+ "eager attention. This warning can be removed using the argument "
70+ '`attn_implementation="eager"` when loading the model.'
71+ )
72+ else:
73+ attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
74+ self.config._attn_implementation
75+ ]
76
77 attn_output, attn_weights = attention_interface(
78 self,
79@@ -83,4 +104,9 @@
80 attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
81 attn_output = self.o_proj(attn_output)
82
83+ if output_attentions:
84+ attn_weights = None
85+
86+ if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
87+ return attn_output, attn_weights, past_key_value
88 return attn_output, attn_weights
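The last lines of this patch keep both calling conventions alive by testing the installed transformers version at runtime (pv refers to packaging.version in the patch sources). A minimal sketch of such a version gate, with a hypothetical helper name:

import transformers
from packaging import version as pv

def attention_result(attn_output, attn_weights, past_key_value):
    # Hypothetical helper mirroring the check in the patch above: versions
    # below 4.53.99 return the cache as a third element, later ones do not.
    if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
        return attn_output, attn_weights, past_key_value
    return attn_output, attn_weights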
auto/patch_transformers: IdeficsEmbedding.forward -> patched_IdeficsEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,9 +1,26 @@
4 def forward(self, x, seq_len=None):
5 # x: [bs, num_attention_heads, seq_len, head_size]
6- if seq_len > self.max_seq_len_cached:
7- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
8+ # if seq_len > self.max_seq_len_cached:
9+ # self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
10
11- return (
12- self.cos_cached[:seq_len].to(dtype=x.dtype),
13- self.sin_cached[:seq_len].to(dtype=x.dtype),
14+ def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
15+ t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
16+ # freqs = torch.einsum("i,j->ij", t, inv_freq)
17+ freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
18+ emb = torch.cat((freqs, freqs), dim=-1)
19+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
20+
21+ def _set_cos_sin_cache_else(_x, _inv_freq, _seq_len, cos_cached, sin_cached):
22+ torch._check(seq_len.item() <= cos_cached.shape[0])
23+ co = cos_cached[: seq_len.item()].detach().clone()
24+ torch._check(seq_len.item() <= sin_cached.shape[0])
25+ si = sin_cached[: seq_len.item()].detach().clone()
26+ return co.to(dtype=x.dtype), si.to(dtype=x.dtype)
27+
28+ cos_cached, sin_cached = torch.cond(
29+ (seq_len > self.max_seq_len_cached).item(),
30+ _set_cos_sin_cache_then,
31+ _set_cos_sin_cache_else,
32+ [x, self.inv_freq, seq_len, self.cos_cached, self.sin_cached],
33 )
34+ return cos_cached, sin_cached
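The else branch of the patched IdeficsEmbedding.forward slices the cached cos/sin tables with a data-dependent length and uses torch._check to tell the exporter that the slice stays within the cached buffers. The same idea in isolation (take_prefix is an illustrative name):

import torch

def take_prefix(cache, n):
    # n is a 0-d integer tensor; n.item() becomes an unbacked symbolic
    # integer at export time, and torch._check bounds it for the exporter.
    length = n.item()
    torch._check(length <= cache.shape[0])
    return cache[:length].clone()

cache = torch.arange(10.0)
print(take_prefix(cache, torch.tensor(4)))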
auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+        # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+        # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
auto/patch_transformers: MistralRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+        # The test should always be True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+        # Let's keep in mind that self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
auto/patch_transformers: MixtralRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
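The recurring idea in the rewritten dynamic_frequency_update and longrope_frequency_update above is to replace a Python if over a data-dependent value (seq_len) with torch.cond, so that torch.export keeps both branches in the graph instead of specializing on the value seen while tracing. A minimal standalone sketch of that pattern (the threshold 4096 and the function name are illustrative, not taken from the patch):

import torch

def select_inv_freq(seq_len, long_inv_freq, original_inv_freq):
    # both branches stay in the exported graph; the predicate is evaluated at runtime
    return torch.cond(
        seq_len >= 4096,
        lambda a, b: a.clone(),
        lambda a, b: b.clone(),
        [long_inv_freq, original_inv_freq],
    )

print(select_inv_freq(torch.tensor(8192), torch.rand(64), torch.rand(64)).shape)

The .clone() calls mirror the patch: when exported, the branches of torch.cond are not allowed to return aliases of their inputs.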
auto/patch_transformers: Phi3RotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
The diff is identical to the one shown above for MixtralRotaryEmbedding.forward: the forward pass is replaced by common_RotaryEmbedding.forward and the dynamic_rope_update decorator by patched_dynamic_rope_update.
auto/patch_transformers: Phi4MultimodalRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
Identical diff to the one shown above for MixtralRotaryEmbedding.forward.
auto/patch_transformers: PhiRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
Identical diff to the one shown above for MixtralRotaryEmbedding.forward.
auto/patch_transformers: Qwen2_5_VLForConditionalGeneration.prepare_inputs_for_generation -> patched_Qwen2_5_VLForConditionalGeneration.prepare_inputs_for_generation¶
1--- original
2+++ rewritten
3@@ -14,9 +14,12 @@
4 second_per_grid_ts=None,
5 **kwargs,
6 ):
7- # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
8+ # Overwritten -- in specific circumstances we don't want to f
9+ # forward image inputs to the model
10+ from transformers.generation import GenerationMixin
11
12- model_inputs = super().prepare_inputs_for_generation(
13+ model_inputs = GenerationMixin.prepare_inputs_for_generation(
14+ self,
15 input_ids,
16 past_key_values=past_key_values,
17 attention_mask=attention_mask,
18@@ -36,8 +39,8 @@
19 if position_ids is None:
20 # Calculate RoPE index once per generation in the pre-fill stage only.
21 # When compiling, we can't check tensor values thus we check only input length
22- # It is safe to assume that `length!=1` means we're in pre-fill because compiled
23- # models currently cannot do assisted decoding
24+ # It is safe to assume that `length!=1` means we're in pre-fill
25+ # because compiled models currently cannot do assisted decoding
26 if cache_position[0] == 0 or self.model.rope_deltas is None:
27 vision_positions, rope_deltas = self.model.get_rope_index(
28 model_inputs.get("input_ids", None),
29@@ -48,7 +51,7 @@
30 )
31 self.model.rope_deltas = rope_deltas
32 # then use the prev pre-calculated rope-deltas to get the correct position ids
33- elif "position_ids" in model_inputs:
34+ elif "position_ids" in model_inputs and model_inputs["position_ids"] is not None:
35 batch_size, seq_length = model_inputs["position_ids"].shape
36 device = model_inputs["position_ids"].device
37 position_ids = torch.arange(seq_length, device=device)
38@@ -58,7 +61,14 @@
39 vision_positions = position_ids + delta.expand_as(position_ids)
40
41 # Concatenate "text + vision" positions into [4, bs, seq-len]
42- text_positions = model_inputs["position_ids"][None, ...]
43+ if "position_ids" not in model_inputs or model_inputs["position_ids"] is None:
44+ text_positions = torch.arange(input_ids.shape[1], device=input_ids.device)[
45+ None, None, :
46+ ]
47+ else:
48+ text_positions = model_inputs["position_ids"][None, ...]
49+ # text_positions = model_inputs["position_ids"][None, ...]
50+ assert vision_positions is not None, "vision_positions are missing"
51 model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
52
53 if cache_position[0] != 0:
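One detail of this patch worth noting: the zero-argument super() call is replaced by an explicit GenerationMixin.prepare_inputs_for_generation(self, ...). Zero-argument super() resolves through the class the function was defined in, not the class the patch is grafted onto, so it cannot be relied upon here; calling the base class explicitly sidesteps that. A minimal reproduction of the underlying Python behaviour (class names are illustrative):

class Base:
    def prepare(self, x):
        return x + 1

class Model(Base):
    pass

def patched_prepare(self, x):
    # super().prepare(x) would fail here: no __class__ cell is bound to this function
    return Base.prepare(self, x) * 2

Model.prepare = patched_prepare
print(Model().prepare(1))  # 4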
auto/patch_transformers: Qwen2_5_VLVisionAttention.forward -> patched_Qwen2_5_VLVisionAttention.forward¶
1--- original
2+++ rewritten
3@@ -7,24 +7,59 @@
4 **kwargs,
5 ) -> torch.Tensor:
6 seq_length = hidden_states.shape[0]
7- query_states, key_states, value_states = (
8- self.qkv(hidden_states)
9- .reshape(seq_length, 3, self.num_heads, -1)
10- .permute(1, 0, 2, 3)
11- .unbind(0)
12+ # PATCHED: avoid the use of unbind
13+ qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3)
14+
15+ query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
16+ cos, sin = position_embeddings
17+
18+ # This part should be moved into the loop
19+ # iteration to enable fusion inside the loop.
20+ query_states, key_states = (
21+ transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.apply_rotary_pos_emb_vision(
22+ query_states, key_states, cos, sin
23+ )
24 )
25- cos, sin = position_embeddings
26- query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
27
28 query_states = query_states.transpose(0, 1).unsqueeze(0)
29 key_states = key_states.transpose(0, 1).unsqueeze(0)
30 value_states = value_states.transpose(0, 1).unsqueeze(0)
31
32- attention_interface: Callable = eager_attention_forward
33+ attention_interface: Callable = (
34+ transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
35+ )
36 if self.config._attn_implementation != "eager":
37- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
38+ # PATCHED
39+ # attention_interface = ALL_ATTENTION_FUNCTIONS[
40+ # self.config._attn_implementation]
41+ attention_interface = transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[
42+ self.config._attn_implementation
43+ ]
44
45- if self.config._attn_implementation == "flash_attention_2":
46+ if self.config._attn_implementation == "flash_attention_2" and _is_torchdynamo_exporting():
47+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
48+ attn_output = torch.onnx.ops.symbolic(
49+ "custom::qwen25_attention",
50+ (
51+ query_states,
52+ key_states,
53+ value_states,
54+ cu_seqlens,
55+ cu_seqlens,
56+ max_seqlen,
57+ max_seqlen,
58+ torch.tensor(self.scaling, dtype=torch.float32),
59+ ),
60+ dtype=query_states.dtype,
61+ shape=(
62+ key_states.shape[0],
63+ value_states.shape[1],
64+ max_seqlen,
65+ value_states.shape[-1],
66+ ),
67+ version=1,
68+ )
69+ elif self.config._attn_implementation == "flash_attention_2":
70 # Flash Attention 2: Use cu_seqlens for variable length attention
71 max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
72 attn_output, _ = attention_interface(
73@@ -42,6 +77,67 @@
74 is_causal=False,
75 **kwargs,
76 )
77+ elif _is_torchdynamo_exporting():
78+ if attention_interface is transformers.integrations.sdpa_attention.sdpa_attention_forward:
79+ attention_interface = patched_sdpa_attention_forward
80+
81+ if use_loop_for_attention_in_qwen_2_5:
82+
83+ def _iteration(start_end, query_states, key_states, value_states):
84+ return patched_Qwen2_5_VLVisionAttentionOneIteration.forward(
85+ self,
86+ start_end,
87+ query_states,
88+ key_states,
89+ value_states,
90+ scaling=self.scaling,
91+ dropout=0.0 if not self.training else self.attention_dropout,
92+ )
93+
94+ starts = cu_seqlens[:-1]
95+ ends = cu_seqlens[1:]
96+ # cu_seqlens = [0, 10, 14, 27]
97+ # starts: [0, 10, 14]
98+ # ends: [10, 14, 17]
99+ # starts_ends: [[0, 10], [10, 14], [14, 27]]
100+ starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], dim=1)
101+ attn_outputs = [
102+ _iteration(start_end, query_states, key_states, value_states)
103+ for start_end in starts_ends
104+ ]
105+ # attn_outputs = torch._higher_order_ops.while_loop(
106+ # attn_outputs = torch.ops.higher_order.while_loop(
107+ # (lambda it, starts_ends, *_args: it < starts_ends.shape[0]),
108+ # _iteration,
109+ # (torch.tensor(0),
110+ # starts_ends, query_states, key_states, value_states), tuple(),
111+ # )
112+ attn_output = torch.cat(attn_outputs, dim=1)
113+ else:
114+ # make square mask
115+ indices = torch.arange(
116+ cu_seqlens.max(), dtype=cu_seqlens.dtype, device=cu_seqlens.device
117+ )
118+ dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(cu_seqlens.dtype)
119+ dot = dot.sum(dim=0)
120+ mask = dot.unsqueeze(1) - dot.unsqueeze(0)
121+ bool_mask = mask == 0
122+ bool_mask = bool_mask.unsqueeze(0).unsqueeze(0)
123+
124+ torch._check(bool_mask.shape[2] == key_states.shape[2])
125+ torch._check(bool_mask.shape[3] == key_states.shape[2])
126+
127+ attn_output, _ = attention_interface(
128+ self,
129+ query_states,
130+ key_states,
131+ value_states,
132+ attention_mask=bool_mask,
133+ scaling=self.scaling,
134+ dropout=0.0 if not self.training else self.attention_dropout,
135+ is_causal=False,
136+ **kwargs,
137+ )
138 else:
139 # Other implementations: Process each chunk separately
140 lengths = cu_seqlens[1:] - cu_seqlens[:-1]
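At export time, the patched attention above either emits a custom custom::qwen25_attention node through torch.onnx.ops.symbolic (flash_attention_2) or replaces the per-chunk loop with a single attention call over a block-diagonal mask derived from cu_seqlens. A minimal sketch of that mask construction, reusing the example values from the comments in the patch (int() is only there to keep the sketch eager-friendly; the patch keeps the size symbolic):

import torch

cu_seqlens = torch.tensor([0, 10, 14, 27], dtype=torch.int64)
indices = torch.arange(int(cu_seqlens.max()), dtype=cu_seqlens.dtype)
dot = (cu_seqlens.unsqueeze(1) <= indices.unsqueeze(0)).to(cu_seqlens.dtype)
dot = dot.sum(dim=0)  # segment id of every position
mask = dot.unsqueeze(1) - dot.unsqueeze(0)
bool_mask = (mask == 0).unsqueeze(0).unsqueeze(0)  # True only inside the same segment
print(bool_mask.shape)  # torch.Size([1, 1, 27, 27])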
auto/patch_transformers: Qwen2_5_VisionTransformerPretrainedModel.get_window_index -> patched_Qwen2_5_VisionTransformerPretrainedModel.get_window_index¶
1--- original
2+++ rewritten
3@@ -1,10 +1,15 @@
4 def get_window_index(self, grid_thw):
5- window_index: list = []
6- cu_window_seqlens: list = [0]
7+ window_index: list = [] # type: ignore[annotation-unchecked]
8+ # PATCHED
9+ cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int64)] # type: ignore[annotation-unchecked]
10 window_index_id = 0
11 vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
12
13- for grid_t, grid_h, grid_w in grid_thw:
14+ for _thw in grid_thw:
15+ # PATCHED: avoid unbind
16+ grid_t = _thw[0]
17+ grid_h = _thw[1]
18+ grid_w = _thw[2]
19 llm_grid_h, llm_grid_w = (
20 grid_h // self.spatial_merge_size,
21 grid_w // self.spatial_merge_size,
22@@ -34,9 +39,11 @@
23 index_padded = index_padded.reshape(-1)
24 index_new = index_padded[index_padded != -100]
25 window_index.append(index_new + window_index_id)
26- cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
27- cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
28+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1:]
29+ # PATCHED
30+ # cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
31+ cu_window_seqlens.append(cu_seqlens_tmp)
32 window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
33 window_index = torch.cat(window_index, dim=0)
34
35- return window_index, cu_window_seqlens
36+ return window_index, torch.cat(cu_window_seqlens, dim=0)
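The rewritten get_window_index keeps cu_window_seqlens as a list of tensors and concatenates them once at the end: the original .tolist() call forces the values to be read, which torch.export cannot keep symbolic. A minimal sketch of the accumulation pattern, with illustrative lengths and the spatial_merge_unit factor left out:

import torch

per_image_seqlens = [torch.tensor([4, 4, 2]), torch.tensor([3, 3])]  # hypothetical values
cu_window_seqlens = [torch.tensor([0], dtype=torch.int64)]
for seqlens in per_image_seqlens:
    cu_tmp = seqlens.cumsum(0) + cu_window_seqlens[-1][-1:]
    cu_window_seqlens.append(cu_tmp)
print(torch.cat(cu_window_seqlens, dim=0))  # tensor([ 0,  4,  8, 10, 13, 16])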
auto/patch_transformers: Qwen2_5_VisionTransformerPretrainedModel.forward -> patched_Qwen2_5_VisionTransformerPretrainedModel.forward¶
1--- original
2+++ rewritten
3@@ -12,11 +12,13 @@
4 hidden_states = self.patch_embed(hidden_states)
5 rotary_pos_emb = self.rot_pos_emb(grid_thw)
6 window_index, cu_window_seqlens = self.get_window_index(grid_thw)
7- cu_window_seqlens = torch.tensor(
8- cu_window_seqlens,
9- device=hidden_states.device,
10- dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
11- )
12+ # PATCHED
13+ # cu_window_seqlens = torch.tensor(
14+ # cu_window_seqlens,
15+ # device=hidden_states.device,
16+ # dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
17+ # )
18+ cu_window_seqlens = cu_window_seqlens.to(hidden_states.device).to(grid_thw.dtype)
19 cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
20
21 seq_len, _ = hidden_states.size()
22@@ -37,8 +39,10 @@
23 dim=0,
24 # Select dtype based on the following factors:
25 # - FA2 requires that cu_seqlens_q must have dtype int32
26- # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
27- # See https://github.com/huggingface/transformers/pull/34852 for more information
28+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype
29+ # as grid_thw
30+ # See https://github.com/huggingface/transformers/pull/34852
31+ # for more information
32 dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
33 )
34 cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
35@@ -59,5 +63,4 @@
36 hidden_states = self.merger(hidden_states)
37 reverse_indices = torch.argsort(window_index)
38 hidden_states = hidden_states[reverse_indices, :]
39-
40 return hidden_states
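Because get_window_index now returns a tensor, the forward only needs device and dtype casts where it used to rebuild cu_window_seqlens with torch.tensor(...), which would bake the traced values in as constants. A small eager sketch of the remaining post-processing (the values are illustrative):

import torch

cu_window_seqlens = torch.tensor([0, 4, 4, 8, 10])  # stands in for the patched return value
grid_thw = torch.tensor([[1, 4, 6]])
cu_window_seqlens = cu_window_seqlens.to("cpu").to(grid_thw.dtype)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
print(cu_window_seqlens)  # tensor([ 0,  4,  8, 10])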
auto/patch_transformers: Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb -> patched_Qwen2_5_VisionTransformerPretrainedModel.rot_pos_emb¶
1--- original
2+++ rewritten
3@@ -1,6 +1,10 @@
4 def rot_pos_emb(self, grid_thw):
5 pos_ids = []
6- for t, h, w in grid_thw:
7+ for thw_ in grid_thw:
8+ # PATCHED: avoid unbind
9+ t = thw_[0]
10+ h = thw_[1]
11+ w = thw_[2]
12 hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
13 hpos_ids = hpos_ids.reshape(
14 h // self.spatial_merge_size,
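Both Qwen2.5-VL loops above index each row of grid_thw instead of tuple-unpacking it, avoiding the unbind call that unpacking triggers and that the exporter handles less well. A minimal sketch of the pattern with hypothetical (t, h, w) values:

import torch

grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])  # hypothetical (t, h, w) grid sizes
for _thw in grid_thw:
    t, h, w = _thw[0], _thw[1], _thw[2]
    print(int(t), int(h), int(w))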
auto/patch_transformers: Qwen3MoeSparseMoeBlock.forward -> patched_Qwen3MoeSparseMoeBlock.forward¶
1--- original
2+++ rewritten
3@@ -1,6 +1,67 @@
4-def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
5+def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
6+ """ """
7 batch_size, sequence_length, hidden_dim = hidden_states.shape
8- hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
9- routing_weights, selected_experts = self.router(hidden_states_reshaped)
10- final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
11- return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
12+ hidden_states = hidden_states.view(-1, hidden_dim)
13+ # router_logits: (batch * sequence_length, n_experts)
14+ router_logits = self.gate(hidden_states)
15+
16+ routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
17+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
18+ if self.norm_topk_prob: # only diff with mixtral sparse moe block!
19+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
20+ # we cast back to the input dtype
21+ routing_weights = routing_weights.to(hidden_states.dtype)
22+
23+ final_hidden_states = torch.zeros(
24+ (batch_size * sequence_length, hidden_dim),
25+ dtype=hidden_states.dtype,
26+ device=hidden_states.device,
27+ )
28+
29+ # One hot encode the selected experts to create an expert mask
30+ # this will be used to easily index which expert is going to be sollicitated
31+ expert_mask = torch.nn.functional.one_hot(
32+ selected_experts, num_classes=self.num_experts
33+ ).permute(2, 1, 0)
34+
35+ # Loop over all available experts in the model
36+ # and perform the computation on each expert
37+ expert_sum = expert_mask.sum(dim=(-1, -2))
38+ # expert_hit = torch.greater(expert_sum, 0).nonzero()
39+ # for expert_idx in expert_hit:
40+ for expert_idx in range(self.num_experts):
41+ # initial code has a squeeze but it is not possible to do that.
42+ # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
43+ expert_mask_idx = expert_mask[expert_idx]
44+ final_hidden_states = torch.cond(
45+ (expert_sum[expert_idx] > 0).item(),
46+ lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop( # noqa: E501
47+ final_hidden_states,
48+ expert_mask,
49+ hidden_states,
50+ routing_weights,
51+ expert_idx=_i,
52+ ),
53+ lambda final_hidden_states, *args: final_hidden_states.clone(),
54+ [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
55+ )
56+
57+ # if expert_sum[expert_idx] > 0:
58+ # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
59+
60+ # Index the correct hidden states and compute the expert hidden state for
61+ # the current expert. We need to make sure to multiply the output hidden
62+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
63+ # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
64+ # current_hidden_states = (
65+ # expert_layer(current_state) * routing_weights[top_x, idx, None]
66+ # )
67+
68+ # However `index_add_` only support torch tensors for indexing so we'll use
69+ # the `top_x` tensor here.
70+ # final_hidden_states.index_add_(
71+ # 0, top_x, current_hidden_states.to(hidden_states.dtype)
72+ # )
73+
74+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
75+ return final_hidden_states, router_logits
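The data-dependent test if expert_sum[expert_idx] > 0 is replaced by torch.cond so that both branches stay in the exported graph; the skipped branch must return a tensor of the same shape, hence the .clone(). A self-contained sketch of the pattern with a stand-in for the expert computation:

import torch

def run_expert(final_hidden_states, hidden_states):
    # stand-in for _forward_expert_loop
    return final_hidden_states + hidden_states.sum(dim=0, keepdim=True)

final_hidden_states = torch.zeros(1, 4)
hidden_states = torch.randn(3, 4)
expert_sum = torch.tensor([2, 0])  # number of tokens routed to each of two experts
for expert_idx in range(2):
    final_hidden_states = torch.cond(
        (expert_sum[expert_idx] > 0).item(),
        run_expert,
        lambda final_hidden_states, hidden_states: final_hidden_states.clone(),
        [final_hidden_states, hidden_states],
    )
print(final_hidden_states.shape)  # torch.Size([1, 4])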
auto/patch_transformers: ‘Qwen3MoeSparseMoeBlock_forward_expert_loop’ -> patched_Qwen3MoeSparseMoeBlock._forward_expert_loop¶
1def _forward_expert_loop(
2 self,
3 final_hidden_states,
4 expert_mask_idx,
5 hidden_states,
6 routing_weights,
7 expert_idx: int,
8):
9 # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
10 idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
11 hidden_dim = hidden_states.shape[-1]
12 current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
13 expert_current_state = self.experts[expert_idx](current_state)
14 current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
15 return final_hidden_states.index_add(0, top_x, current_hidden_states.to(hidden_states.dtype))
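The helper uses the out-of-place index_add instead of index_add_ so the torch.cond branch above stays free of in-place mutation, which the exported control-flow operator does not allow. A minimal illustration:

import torch

final_hidden_states = torch.zeros(6, 4)
top_x = torch.tensor([1, 3])   # token positions handled by this expert
update = torch.ones(2, 4)      # expert output already weighted by routing weights
final_hidden_states = final_hidden_states.index_add(0, top_x, update)
print(final_hidden_states[:, 0])  # tensor([0., 1., 0., 1., 0., 0.])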
auto/patch_transformers: SamMaskDecoder.forward -> patched_SamMaskDecoder.forward¶
1--- original
2+++ rewritten
3@@ -5,6 +5,7 @@
4 sparse_prompt_embeddings: torch.Tensor,
5 dense_prompt_embeddings: torch.Tensor,
6 multimask_output: bool,
7+ output_attentions: Optional[bool] = None,
8 attention_similarity: Optional[torch.Tensor] = None,
9 target_embedding: Optional[torch.Tensor] = None,
10 ) -> tuple[torch.Tensor, torch.Tensor]:
11@@ -22,19 +23,31 @@
12 the embeddings of the mask inputs
13 multimask_output (bool):
14 Whether to return multiple masks or a single mask.
15+ output_attentions (bool, *optional*):
16+ Whether or not to return the attentions tensors of all attention layers.
17 """
18 batch_size, num_channels, height, width = image_embeddings.shape
19- point_batch_size = (
20- sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
21- )
22+ point_batch_size = sparse_prompt_embeddings.shape[1]
23 # Concatenate output tokens
24 output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
25 output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
26
27- if sparse_prompt_embeddings is not None:
28- tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
29- else:
30- tokens = output_tokens
31+ # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
32+ # torch.any is needed to avoid data-dependent control flow
33+ # with sparse_prompt_embeddings.sum().item() != 0
34+ def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
35+ return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
36+
37+ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
38+ return output_tokens.clone()
39+
40+ tokens = torch.cond(
41+ torch.any(sparse_prompt_embeddings != 0),
42+ sparse_prompt_embeddings_is_not_empty,
43+ sparse_prompt_embeddings_is_empty,
44+ [output_tokens, sparse_prompt_embeddings],
45+ )
46+
47 point_embeddings = tokens.to(self.iou_token.weight.dtype)
48
49 # Expand per-image data in batch direction to be per-point
50@@ -45,15 +58,21 @@
51 )
52
53 # Run the transformer, image_positional_embedding are consumed
54- point_embedding, image_embeddings = self.transformer(
55+ torch._check(point_embeddings.shape[0] != 0)
56+ torch._check(point_embeddings.shape[1] != 0)
57+ torch._check(point_embeddings.shape[2] != 0)
58+ torch._check(point_embeddings.shape[3] != 0)
59+ embeddings_attentions = self.transformer(
60 point_embeddings=point_embeddings,
61 image_embeddings=image_embeddings,
62 image_positional_embeddings=image_positional_embeddings,
63 attention_similarity=attention_similarity,
64 target_embedding=target_embedding,
65+ output_attentions=output_attentions,
66 )
67- iou_token_out = point_embedding[:, :, 0, :]
68- mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
69+ point_embedding, image_embeddings = embeddings_attentions[:2]
70+ iou_token_out = torch.select(point_embedding, dim=2, index=0)
71+ mask_tokens_out = torch.narrow(point_embedding, dim=2, start=1, length=self.num_mask_tokens)
72
73 # Upscale mask embeddings and predict masks using the mask tokens
74 image_embeddings = image_embeddings.transpose(2, 3).reshape(
75@@ -88,4 +107,15 @@
76 mask_slice = slice(0, 1)
77 masks = masks[:, :, mask_slice, :, :]
78 iou_pred = iou_pred[:, :, mask_slice]
79- return masks, iou_pred
80+
81+ outputs = (masks, iou_pred)
82+
83+ if len(embeddings_attentions) == 2:
84+ # transformers==4.54
85+ return outputs
86+
87+ if output_attentions and len(embeddings_attentions) > 2:
88+ outputs = outputs + (embeddings_attentions[2],) # noqa: RUF005
89+ else:
90+ outputs = outputs + (None,) # noqa: RUF005
91+ return outputs
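The None test on sparse_prompt_embeddings becomes a value test wrapped in torch.cond, so the exported graph keeps both the concatenation branch and the pass-through branch. A minimal sketch of the rewrite (the predicate is reduced to a Python bool with .item() here only to keep the sketch trivially runnable in eager mode; the patch passes the boolean tensor itself):

import torch

output_tokens = torch.randn(1, 1, 5, 8)
sparse_prompt_embeddings = torch.zeros(1, 1, 2, 8)  # all zeros plays the role of "no sparse prompt"

tokens = torch.cond(
    torch.any(sparse_prompt_embeddings != 0).item(),
    lambda out, sp: torch.cat((out, sp), dim=2),
    lambda out, sp: out.clone(),
    [output_tokens, sparse_prompt_embeddings],
)
print(tokens.shape)  # torch.Size([1, 1, 5, 8]) because the prompt is empty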
auto/patch_transformers: SmolLM3RotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original
2+++ rewritten
3@@ -1,8 +1,16 @@
4-@torch.no_grad()
5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
6-def forward(self, x, position_ids):
7+@patched_dynamic_rope_update
8+def forward(self, x, position_ids, layer_type=None):
9+ if layer_type is not None:
10+ # transformers>=5.0
11+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
13+ else:
14+ # transformers<5.0
15+ inv_freq = self.inv_freq
16+ attention_scaling = self.attention_scaling
17+
18 inv_freq_expanded = (
19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
21 )
22 position_ids_expanded = position_ids[:, None, :].float()
23
24@@ -12,7 +20,7 @@
25 with torch.autocast(device_type=device_type, enabled=False): # Force float32
26 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
27 emb = torch.cat((freqs, freqs), dim=-1)
28- cos = emb.cos() * self.attention_scaling
29- sin = emb.sin() * self.attention_scaling
30+ cos = emb.cos() * attention_scaling
31+ sin = emb.sin() * attention_scaling
32
33 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
34
35--- original
36+++ rewritten
37@@ -1,99 +1,193 @@
38-def dynamic_rope_update(rope_forward):
39- """
40- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
41- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
42+def patched_dynamic_rope_update(rope_forward):
43+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
44
45- Args:
46- rope_forward (Callable):
47- The forward pass of the RoPE implementation.
48+ ``rope_type`` is determined in the constructor of class
49+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
50
51- Returns:
52- The decorated forward pass.
53+ .. code-block:: python
54+
55+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
56+ self.rope_type = config.rope_scaling.get(
57+ "rope_type", config.rope_scaling.get("type"))
58+ else:
59+ self.rope_type = "default"
60+
61+ The original code of the patched function:
62+
63+ .. code-block:: python
64+
65+ def dynamic_rope_update(rope_forward):
66+ def longrope_frequency_update(self, position_ids, device):
67+ seq_len = torch.max(position_ids) + 1
68+ if hasattr(self.config, "original_max_position_embeddings"):
69+ original_max_position_embeddings =
70+ self.config.original_max_position_embeddings
71+ else:
72+ original_max_position_embeddings =
73+ self.config.max_position_embeddings
74+ if seq_len > original_max_position_embeddings:
75+ if not hasattr(self, "long_inv_freq"):
76+ self.long_inv_freq, _ = self.rope_init_fn(
77+ self.config, device, seq_len=original_max_position_embeddings + 1
78+ )
79+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
80+ else:
81+ self.original_inv_freq = self.original_inv_freq.to(device)
82+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
83+
84+ def dynamic_frequency_update(self, position_ids, device):
85+ seq_len = torch.max(position_ids) + 1
86+ if seq_len > self.max_seq_len_cached: # growth
87+ inv_freq, self.attention_scaling = self.rope_init_fn(
88+ self.config, device, seq_len=seq_len)
89+ self.register_buffer("inv_freq", inv_freq, persistent=False)
90+ self.max_seq_len_cached = seq_len
91+
92+ if seq_len < self.original_max_seq_len and
93+ self.max_seq_len_cached > self.original_max_seq_len:
94+ self.original_inv_freq = self.original_inv_freq.to(device)
95+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
96+ self.max_seq_len_cached = self.original_max_seq_len
97+
98+ @wraps(rope_forward)
99+ def wrapper(self, x, position_ids):
100+ if "dynamic" in self.rope_type:
101+ dynamic_frequency_update(self, position_ids, device=x.device)
102+ elif self.rope_type == "longrope":
103+ longrope_frequency_update(self, position_ids, device=x.device)
104+ return rope_forward(self, x, position_ids)
105+
106+ return wrapper
107+
108 """
109
110 def longrope_frequency_update(self, position_ids, device, layer_type=None):
111- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+ # It is no use to patch the function after the model is created
113+ # as rope_init_fn is an attribute set to one function when the model
114+ # is created and when no patch is applied yet.
115+ # So we select the patched version here.
116+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117 seq_len = torch.max(position_ids) + 1
118- original_max_position_embeddings = getattr(
119- self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120- )
121+ if hasattr(self.config, "original_max_position_embeddings"):
122+ original_max_position_embeddings = self.config.original_max_position_embeddings
123+ else:
124+ original_max_position_embeddings = self.config.max_position_embeddings
125+
126 if layer_type is None:
127- rope_type = self.rope_type
128+ # rope_type = self.rope_type
129 original_inv_freq = self.original_inv_freq
130 prefix = ""
131 else:
132- rope_type = self.rope_type[layer_type]
133+ # rope_type = self.rope_type[layer_type]
134 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135 prefix = f"{layer_type}_"
136
137- if seq_len > original_max_position_embeddings:
138- if not hasattr(self, f"{layer_type}_long_inv_freq"):
139- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140- long_inv_freq, _ = rope_init_fn(
141- self.config,
142- device,
143- seq_len=original_max_position_embeddings + 1,
144- layer_type=layer_type,
145- )
146- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148- else:
149- # This .to() is needed if the model has been moved to a device after being initialized (because
150- # the buffer is automatically moved, but not the original copy)
151- original_inv_freq = original_inv_freq.to(device)
152- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+ # At export time, seq_len is unknown.
155+ long_inv_freq, _ = rope_init_fn(
156+ self.config, device, seq_len=original_max_position_embeddings + 1
157+ )
158+ original_inv_freq = self.original_inv_freq.to(device)
159+
160+ # PATCHED: uses torch.cond instead of a test
161+ cond = (seq_len > original_max_position_embeddings).item()
162+ inv_freq = torch.cond(
163+ cond,
164+ (lambda x, y: x.clone()),
165+ (lambda x, y: y.clone()),
166+ [long_inv_freq, original_inv_freq],
167+ )
168+ setattr(self, f"{prefix}inv_freq", inv_freq)
169+ # if seq_len > original_max_position_embeddings:
170+ # self.inv_freq = self.long_inv_freq
171+ # else:
172+ # self.inv_freq = self.original_inv_freq
173
174 def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175- """
176- dynamic RoPE layers should recompute `inv_freq` in the following situations:
177- 1 - growing beyond the cached sequence length (allow scaling)
178- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179- """
180+ # constructor:
181+ # - self.max_seq_len_cached = config.max_position_embeddings
182+ # - self.original_max_seq_len = config.max_position_embeddings
183+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+ # It is no use to patch the function after the model is created
186+ # as rope_init_fn is an attribute set to one function when the model
187+ # is created and when no patch is applied yet.
188+ # So we select the patched version here.
189+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+ # This behaviour is difficult to translate.
192+ # The sequence always grows.
193+ # The test should always True.
194+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+ #
196+ # if seq_len > self.max_seq_len_cached: # growth
197+ # inv_freq, self.attention_scaling = self.rope_init_fn(
198+ # self.config, device, seq_len=seq_len
199+ # )
200+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
201+ # self.max_seq_len_cached = seq_len
202+ #
203+ # So we should not need what follows.
204+ #
205+ # cond = (seq_len > self.max_seq_len_cached).item()
206+ # self.attention_scaling = torch.cond(
207+ # cond,
208+ # (lambda x, y: x.clone()),
209+ # (lambda x, y: y.clone()),
210+ # [attention_scaling, self.attention_scaling],
211+ # )
212+
213 seq_len = torch.max(position_ids) + 1
214+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216 if layer_type is None:
217- rope_type = self.rope_type
218- max_seq_len_cached = self.max_seq_len_cached
219+ # rope_type = self.rope_type
220+ # max_seq_len_cached = self.max_seq_len_cached
221 original_inv_freq = self.original_inv_freq
222 prefix = ""
223 else:
224- rope_type = self.rope_type[layer_type]
225- max_seq_len_cached = getattr(
226- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227- )
228+ # rope_type = self.rope_type[layer_type]
229+ # max_seq_len_cached = getattr(
230+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+ # )
232 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233 prefix = f"{layer_type}_"
234
235- if seq_len > max_seq_len_cached: # growth
236- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237- inv_freq, self.attention_scaling = rope_init_fn(
238- self.config,
239- device,
240- seq_len=seq_len,
241- layer_type=layer_type,
242- )
243- # TODO joao: may break with compilation
244- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+ # Second test to translate.
247+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+ # But in that case the following condition is a way to restore the original cache.
249
250- if (
251- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252- ): # reset
253- # This .to() is needed if the model has been moved to a device after being initialized (because
254- # the buffer is automatically moved, but not the original copy)
255- original_inv_freq = original_inv_freq.to(device)
256- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+ # if (
260+ # seq_len < self.original_max_seq_len
261+ # and self.max_seq_len_cached > self.original_max_seq_len
262+ # ):
263+ # self.original_inv_freq = self.original_inv_freq.to(device)
264+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+ # self.max_seq_len_cached = self.original_max_seq_len
266+
267+ original_inv_freq = self.original_inv_freq.to(device)
268+ cond = (seq_len >= self.original_max_seq_len).item()
269+ # PATCHED: uses torch.cond instead of a test
270+ inv_freq = torch.cond(
271+ cond,
272+ (lambda x, y: x.clone()),
273+ (lambda x, y: y.clone()),
274+ [long_inv_freq, original_inv_freq],
275+ )
276+ setattr(self, f"{prefix}inv_freq", inv_freq)
277
278 @wraps(rope_forward)
279 def wrapper(self, x, position_ids, layer_type=None):
280- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282- if "dynamic" in rope_type:
283- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284- elif rope_type == "longrope":
285- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286- return rope_forward(self, x, position_ids, **kwargs)
287+ if layer_type is None:
288+ if "dynamic" in self.rope_type:
289+ dynamic_frequency_update(self, position_ids, device=x.device)
290+ elif self.rope_type == "longrope":
291+ longrope_frequency_update(self, position_ids, device=x.device)
292+ return rope_forward(self, x, position_ids)
293+
294+ if "dynamic" in self.rope_type:
295+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+ elif self.rope_type == "longrope":
297+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+ return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300 return wrapper
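In both frequency updates the sequence-length comparison is turned into a boolean with .item() and torch.cond selects between the recomputed and the original inverse frequencies, so no Python branch depends on the traced sequence length. A small sketch with illustrative values:

import torch

original_inv_freq = torch.arange(1, 5, dtype=torch.float32)
long_inv_freq = original_inv_freq / 2          # stands in for the rope_init_fn output
position_ids = torch.arange(6).unsqueeze(0)
seq_len = torch.max(position_ids) + 1
original_max_position_embeddings = 4

inv_freq = torch.cond(
    (seq_len > original_max_position_embeddings).item(),
    lambda lng, orig: lng.clone(),
    lambda lng, orig: orig.clone(),
    [long_inv_freq, original_inv_freq],
)
print(inv_freq)  # the long frequencies, because seq_len (6) > 4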
auto/patch_transformers: VisionAttention.forward -> patched_VisionAttention.forward¶
1--- original
2+++ rewritten
3@@ -3,69 +3,55 @@
4 hidden_states: torch.Tensor,
5 cu_seqlens: torch.Tensor,
6 rotary_pos_emb: Optional[torch.Tensor] = None,
7- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
8- **kwargs,
9+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
10 ) -> torch.Tensor:
11 seq_length = hidden_states.shape[0]
12- query_states, key_states, value_states = (
13+ q, k, v = (
14 self.qkv(hidden_states)
15 .reshape(seq_length, 3, self.num_heads, -1)
16 .permute(1, 0, 2, 3)
17 .unbind(0)
18 )
19- cos, sin = position_embeddings
20- query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
21+ if position_embeddings is None:
22+ transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
23+ "The attention layers in this model are transitioning from "
24+ " computing the RoPE embeddings internally "
25+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
26+ "to using externally computed "
27+ "`position_embeddings` (Tuple of tensors, containing cos and sin)."
28+ " In v4.54 `rotary_pos_emb` will be "
29+ "removed and `position_embeddings` will be mandatory."
30+ )
31+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
32+ cos = emb.cos()
33+ sin = emb.sin()
34+ else:
35+ cos, sin = position_embeddings
36+ q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
37+ q, k, cos, sin
38+ )
39
40- query_states = query_states.transpose(0, 1).unsqueeze(0)
41- key_states = key_states.transpose(0, 1).unsqueeze(0)
42- value_states = value_states.transpose(0, 1).unsqueeze(0)
43+ attention_mask = torch.full(
44+ [1, seq_length, seq_length],
45+ torch.finfo(q.dtype).min,
46+ device=q.device,
47+ dtype=q.dtype,
48+ )
49+ # for i in range(1, len(cu_seqlens)):
50+ # attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
51+ # cu_seqlens[i - 1] : cu_seqlens[i]] = 0
52+ attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
53
54- attention_interface: Callable = eager_attention_forward
55- if self.config._attn_implementation != "eager":
56- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
57-
58- if self.config._attn_implementation == "flash_attention_2":
59- # Flash Attention 2: Use cu_seqlens for variable length attention
60- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
61- attn_output, _ = attention_interface(
62- self,
63- query_states,
64- key_states,
65- value_states,
66- attention_mask=None,
67- scaling=self.scaling,
68- dropout=0.0 if not self.training else self.attention_dropout,
69- cu_seq_lens_q=cu_seqlens,
70- cu_seq_lens_k=cu_seqlens,
71- max_length_q=max_seqlen,
72- max_length_k=max_seqlen,
73- is_causal=False,
74- **kwargs,
75- )
76- else:
77- # Other implementations: Process each chunk separately
78- lengths = cu_seqlens[1:] - cu_seqlens[:-1]
79- splits = [
80- torch.split(tensor, lengths.tolist(), dim=2)
81- for tensor in (query_states, key_states, value_states)
82- ]
83-
84- attn_outputs = [
85- attention_interface(
86- self,
87- q,
88- k,
89- v,
90- attention_mask=None,
91- scaling=self.scaling,
92- dropout=0.0 if not self.training else self.attention_dropout,
93- is_causal=False,
94- **kwargs,
95- )[0]
96- for q, k, v in zip(*splits)
97- ]
98- attn_output = torch.cat(attn_outputs, dim=1)
99-
100- attn_output = attn_output.reshape(seq_length, -1).contiguous()
101+ q = q.transpose(0, 1)
102+ k = k.transpose(0, 1)
103+ v = v.transpose(0, 1)
104+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
105+ attn_weights = attn_weights + attention_mask
106+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
107+ q.dtype
108+ )
109+ attn_output = torch.matmul(attn_weights, v)
110+ attn_output = attn_output.transpose(0, 1)
111+ attn_output = attn_output.reshape(seq_length, -1)
112 attn_output = self.proj(attn_output)
113 return attn_output
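The rewritten VisionAttention drops the Python loop that zeroed out blocks of the attention mask and calls rewrite_loop_for_square_mask instead; the result is an additive mask that is 0 inside each cu_seqlens segment and the dtype minimum elsewhere. A loop-free sketch of what such a mask looks like (an illustration, not the helper's implementation):

import torch

cu_seqlens = torch.tensor([0, 3, 5])
seq_length = int(cu_seqlens[-1])
segment = torch.searchsorted(cu_seqlens, torch.arange(seq_length), right=True)
same_segment = segment.unsqueeze(1) == segment.unsqueeze(0)
attention_mask = torch.full((seq_length, seq_length), torch.finfo(torch.float32).min)
attention_mask = attention_mask.masked_fill(same_segment, 0.0)
print(attention_mask)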
auto/patch_transformers: eager_attention_forward -> patched_model_bart_eager_attention_forward¶
1--- original
2+++ rewritten
3@@ -1,27 +1,23 @@
4-def eager_attention_forward(
5- module: nn.Module,
6+def patched_model_bart_eager_attention_forward(
7+ module: torch.nn.Module,
8 query: torch.Tensor,
9 key: torch.Tensor,
10 value: torch.Tensor,
11 attention_mask: Optional[torch.Tensor],
12 scaling: Optional[float] = None,
13 dropout: float = 0.0,
14- **kwargs: Unpack[TransformersKwargs],
15+ head_mask: Optional[torch.Tensor] = None,
16+ **kwargs,
17 ):
18- if scaling is None:
19- scaling = query.size(-1) ** -0.5
20-
21- # Take the dot product between "query" and "key" to get the raw attention scores.
22- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24- if attention_mask is not None:
25- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26- attn_weights = attn_weights + attention_mask
27-
28- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31- attn_output = torch.matmul(attn_weights, value)
32- attn_output = attn_output.transpose(1, 2).contiguous()
33-
34- return attn_output, attn_weights
35+ """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
36+ return common_eager_attention_forward(
37+ module,
38+ query,
39+ key,
40+ value,
41+ attention_mask=attention_mask,
42+ scaling=scaling,
43+ dropout=dropout,
44+ head_mask=head_mask,
45+ **kwargs,
46+ )
auto/patch_transformers: eager_attention_forward -> patched_modeling_marian_eager_attention_forward¶
1--- original
2+++ rewritten
3@@ -1,27 +1,23 @@
4-def eager_attention_forward(
5- module: nn.Module,
6+def patched_modeling_marian_eager_attention_forward(
7+ module: torch.nn.Module,
8 query: torch.Tensor,
9 key: torch.Tensor,
10 value: torch.Tensor,
11 attention_mask: Optional[torch.Tensor],
12 scaling: Optional[float] = None,
13 dropout: float = 0.0,
14- **kwargs: Unpack[TransformersKwargs],
15+ head_mask: Optional[torch.Tensor] = None,
16+ **kwargs,
17 ):
18- if scaling is None:
19- scaling = query.size(-1) ** -0.5
20-
21- # Take the dot product between "query" and "key" to get the raw attention scores.
22- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
23-
24- if attention_mask is not None:
25- attention_mask = attention_mask[:, :, :, : key.shape[-2]]
26- attn_weights = attn_weights + attention_mask
27-
28- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
29- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
30-
31- attn_output = torch.matmul(attn_weights, value)
32- attn_output = attn_output.transpose(1, 2).contiguous()
33-
34- return attn_output, attn_weights
35+ """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
36+ return common_eager_attention_forward(
37+ module,
38+ query,
39+ key,
40+ value,
41+ attention_mask=attention_mask,
42+ scaling=scaling,
43+ dropout=dropout,
44+ head_mask=head_mask,
45+ **kwargs,
46+ )
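Both the bart and marian wrappers above forward the extra head_mask argument to common_eager_attention_forward instead of reimplementing attention. A hedged sketch of what eager attention with an optional head mask amounts to (an illustration, not the helper's actual code):

import torch

def eager_attention_with_head_mask(query, key, value, attention_mask=None,
                                   scaling=None, head_mask=None):
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output.transpose(1, 2).contiguous(), attn_weights

q = k = v = torch.randn(1, 2, 3, 4)  # (batch, heads, sequence, head_dim)
out, weights = eager_attention_with_head_mask(q, k, v, head_mask=torch.ones(2))
print(out.shape, weights.shape)  # torch.Size([1, 3, 2, 4]) torch.Size([1, 2, 3, 3])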