Patches List#
torch#
<<<
import textwrap
from yobx.torch import apply_patches_for_model
from yobx.torch.in_torch.patches import PATCHES

with apply_patches_for_model(patch_torch=True):
    # One bullet link per patch, followed by a titled section with its diff.
    links = []
    sections = []
    for index, patch in enumerate(PATCHES, start=1):
        anchor = f"patch-torch-{index}"
        links.append(f"* :ref:`{anchor}`")
        title = patch.patch.__qualname__
        sections += [
            "",
            f".. _{anchor}:",
            "",
            title,
            "-" * len(title),
            "",
            "::",
            "",
            textwrap.indent(patch.format_diff(), " "),
            "",
        ]
    print("\n".join([*links, "", *sections]))
>>>
patched_DynamicDimConstraintPrinter._print_Symbol#
torch: DynamicDimConstraintPrinter._print_Symbol -> patched_DynamicDimConstraintPrinter._print_Symbol
--- original
+++ rewritten
@@ -1,6 +1,5 @@
def _print_Symbol(self, expr: sympy.Symbol) -> str:
- if not isinstance(expr, sympy.Symbol):
- raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}")
- if not self.symbol_to_source.get(expr):
- raise AssertionError(f"Unknown symbol {expr} created by constraints solver")
- return self.symbol_to_source[expr][0].name
+ assert isinstance(expr, sympy.Symbol), str(type(expr))
+ if self.symbol_to_source.get(expr): # type: ignore
+ return self.symbol_to_source[expr][0].name # type: ignore
+ return str(expr)
patched_infer_size#
torch: infer_size -> patched_infer_size
--- original
+++ rewritten
@@ -1,10 +1,13 @@
-def infer_size(a: Sequence[IntLikeType], b: Sequence[IntLikeType]) -> tuple[IntLikeType, ...]:
- from torch.fx.experimental.symbolic_shapes import guard_or_false
-
+def patched_infer_size(a, b):
+ """
+ Patches ``torch._subclasses.fake_impls.infer_size``.
+ This patch is needed to export
+ :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`.
+ """
dimsA = len(a)
dimsB = len(b)
ndim = max(dimsA, dimsB)
- expandedSizes: list[IntLikeType] = [0] * ndim
+ expandedSizes = [0] * ndim
for i in range(ndim - 1, -1, -1):
offset = ndim - 1 - i
dimA = dimsA - 1 - offset
@@ -23,11 +26,21 @@
# expression of an or statement as-is, without bool()'ing it; if this
# were not the case, we'd need to write this using torch.sym_or() or
# something like that).
- torch._check(
- guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB,
- lambda: f"The size of tensor a ({sizeA}) "
- f"must match the size of tensor b ({sizeB}) "
- f"at non-singleton dimension {i})",
- )
- expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
+ try:
+ b1 = fx_symbolic_shapes.guard_or_false(sizeA == 1)
+ except fx_symbolic_shapes.GuardOnDataDependentSymNode:
+ b1 = False
+ try:
+ b2 = fx_symbolic_shapes.guard_or_false(sizeB == 1)
+ except fx_symbolic_shapes.GuardOnDataDependentSymNode:
+ b2 = False
+ try:
+ b3 = fx_symbolic_shapes.guard_or_false(sizeA == sizeB)
+ except fx_symbolic_shapes.GuardOnDataDependentSymNode:
+ b3 = False
+ if b1 or b2 or b3:
+ expandedSizes[i] = sizeB if fx_symbolic_shapes.guard_or_false(sizeA == 1) else sizeA
+ else:
+ # PATCHED: generic case, the dimension is known, no need to assert
+ expandedSizes[i] = torch.sym_max(sizeA, sizeB) # type: ignore
return tuple(expandedSizes)
patched__broadcast_shapes#
torch: _broadcast_shapes -> patched__broadcast_shapes
--- original
+++ rewritten
@@ -1,12 +1,11 @@
-def _broadcast_shapes(*_shapes):
- from torch.fx.experimental.symbolic_shapes import (
- guard_or_false,
- guarding_hint_or_throw,
- has_guarding_hint,
- is_nested_int,
- )
-
- backed_so = torch.fx.experimental._config.backed_size_oblivious
+def patched__broadcast_shapes(*_shapes):
+ """
+ Patches ``torch._refs._broadcast_shapes``.
+ This patch is needed to export
+ :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`.
+ """
+ from functools import reduce
+ from torch._prims_common import IntLike
shapes = tuple(
(x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
@@ -19,63 +18,36 @@
for shape in shapes:
if not isinstance(shape, Sequence):
raise RuntimeError(
- "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
+ "Input shapes should be of type ints, a tuple of ints, "
+ "or a list of ints, got ",
shape,
)
# Computes common shape
- common_shape: list[int | torch.SymInt] = [
- 1,
- ] * reduce(max, (len(shape) for shape in shapes))
- for arg_idx, shape in enumerate(shapes):
+ common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
+ for _arg_idx, shape in enumerate(shapes):
for idx in range(-1, -1 - len(shape), -1):
- # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1).
- if is_nested_int(shape[idx]):
+ if fx_symbolic_shapes.is_nested_int(shape[idx]):
# Broadcasting is allowed for (j0, 1) or (j0, j0);
# not (j0, j1), (j0, 5), etc.
- if is_nested_int(common_shape[idx]) and guard_or_false(
- shape[idx] == common_shape[idx]
- ):
+ if fx_symbolic_shapes.is_nested_int(
+ common_shape[idx]
+ ) and fx_symbolic_shapes.guard_or_false(shape[idx] == common_shape[idx]):
continue
else:
- # When backed size oblivious is used, we specialize for broadcasting
- # if its the only way to compile the example input.
- # i.e: s0:1, s1:1 ==>
- # assert s0==s1, no specialization on ==1 or !=1.
- # The non-broadcast path is picked
- # s0:1, s1:4 ==>
- # specialize(s0) to be 1.
- # s0:4, s1:1 ==>
- # specialize(s1) to be 1.
- if (
- backed_so
- and has_guarding_hint(shape[idx])
- and has_guarding_hint(common_shape[idx])
- ):
- a = guarding_hint_or_throw(shape[idx])
- b = guarding_hint_or_throw(common_shape[idx])
- if a == 1 and b != 1:
- torch._check(shape[idx] == 1)
- if b == 1 and a != 1:
- torch._check(common_shape[idx] == 1)
- if guard_or_false(shape[idx] == common_shape[idx]):
+ if fx_symbolic_shapes.guard_or_false(shape[idx] == common_shape[idx]):
continue
-
- if guard_or_false(common_shape[idx] == 1):
+ # PATCHED: two cases, if == for sure, no broadcast,
+ # otherwise maybe broadcast with max(dimensions)
+ if fx_symbolic_shapes.guard_or_false(common_shape[idx] != 1):
+ pass
+ elif fx_symbolic_shapes.guard_or_false(
+ common_shape[idx] == 1
+ ) or fx_symbolic_shapes.guard_or_false(shape[idx] != 1):
if shape[idx] < 0:
raise ValueError("Attempting to broadcast a dimension with negative length!")
- common_shape[idx] = shape[idx]
-
- if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1):
- # broadcast case .
- continue
+ common_shape[idx] = shape[idx] # type: ignore
else:
- # If broadcasting is undecided we pick non-broadcast path and add runtime assertion.
- torch._check(
- common_shape[idx] == shape[idx],
- lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
- f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
- f"should be broadcastable to {common_shape}",
- )
+ common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx]) # type: ignore
return common_shape
patched__get_range_constraints#
torch: _get_range_constraints -> patched__get_range_constraints
--- original
+++ rewritten
@@ -1,42 +1,36 @@
-def _get_range_constraints(
+def patched__get_range_constraints(
mod: torch.nn.Module,
- export_artifact: ExportArtifact,
+ export_artifact: torch.export._trace.ExportArtifact,
args,
kwargs,
dynamic_shapes,
):
+ """
+ Patches ``torch.export._trace._get_range_constraints``.
+ See PR `#174593 <https://github.com/pytorch/pytorch/pull/174593>`_.
+ """
gm: torch.fx.GraphModule = export_artifact.aten.gm
- export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
+ export_graph_signature: torch.export.graph_signature.ExportGraphSignature = (
+ export_artifact.aten.sig
+ )
fake_mode: FakeTensorMode = export_artifact.fake_mode
num_lifted = next(
(
i
for i, s in enumerate(export_graph_signature.input_specs)
- if s.kind == InputKind.USER_INPUT
+ if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
),
len(export_graph_signature.input_specs),
)
- combined_args = _combine_args(mod, args, kwargs)
- # This is because we trace based on the kwargs passed in from user
+ # preserve_order=True:
+ # this is because we trace based on the kwargs passed in from user
# not based on the signature. I feel it would be better to just enforce
# one ordering at the start of tracing to avoid confusions, but that is
# bigger refactor, so do this to unblock for now.
- combined_args_traced_order = {}
- for arg in combined_args:
- if arg not in kwargs:
- combined_args_traced_order[arg] = combined_args[arg]
+ combined_args = _combine_args(mod, args, kwargs, preserve_order=True)
- for key in kwargs:
- combined_args_traced_order[key] = kwargs[key]
-
- combined_args = combined_args_traced_order
-
- range_constraints = make_constraints(
- fake_mode,
- gm,
- combined_args,
- dynamic_shapes,
- num_lifted,
+ range_constraints = torch._export.non_strict_utils.make_constraints(
+ fake_mode, gm, combined_args, dynamic_shapes, num_lifted
)
return range_constraints
patched__maybe_broadcast#
torch: _maybe_broadcast -> patched__maybe_broadcast
--- original
+++ rewritten
@@ -1,15 +1,18 @@
-def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+ """
+ Patches ``torch._refs._maybe_broadcast``.
+ This patch is needed to export
+ :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`.
+ """
+ from torch._prims_common import ShapeType, TensorLike, Number
+
# Computes common shape
- common_shape = _broadcast_shapes(
+ common_shape = patched__broadcast_shapes(
*(t.shape if isinstance(t, TensorLike) else None for t in args)
)
def should_expand(a: ShapeType, b: ShapeType) -> bool:
- from torch.fx.experimental.symbolic_shapes import (
- guard_or_false,
- sym_and,
- sym_or,
- )
+ from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and, sym_or
if len(a) != len(b):
return True
@@ -29,10 +32,15 @@
return True
# u0==u1 assume the same, no broadcasting!
- torch._check(
- x == y,
- lambda: "sizes assumed to be the same due to unbacked broadcasting semantics",
- )
+ # PATCHED: avoid errors
+ return True # guard_or_true(x != y)
+ # torch._check(
+ # x == y,
+ # lambda x=x, y=y: (
+ # f"sizes assumed to be the same due to unbacked "
+ # f"broadcasting semantics x={x!r}, y={y!r}"
+ # ),
+ # )
return False
@@ -42,14 +50,14 @@
elif isinstance(x, Number):
return x
elif isinstance(x, TensorLike):
- if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
+ if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x): # type: ignore
return x
- if should_expand(x.shape, common_shape):
- return x.expand(common_shape)
+ if should_expand(x.shape, common_shape): # type: ignore
+ return x.expand(common_shape) # type: ignore
return x
else:
- raise RuntimeError("Unexpected type when broadcasting: " + str(type(x)) + "!")
+ raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
return tuple(__maybe_broadcast(x, common_shape) for x in args)
transformers#
<<<
import textwrap
from yobx.torch import apply_patches_for_model
from yobx.torch.in_torch.patches import PATCHES

with apply_patches_for_model(patch_transformers=True) as details:
    # One bullet link per applied patch, then a titled section with its diff.
    links = []
    sections = []
    for index, patch in enumerate(details, start=1):
        anchor = f"patch-transformers-{index}"
        links.append(f"* :ref:`{anchor}`")
        title = patch.patch.__qualname__
        sections += [
            "",
            f".. _{anchor}:",
            "",
            title,
            "-" * len(title),
            "",
            "::",
            "",
            textwrap.indent(patch.format_diff(), " "),
            "",
        ]
    print("\n".join([*links, "", *sections]))
>>>
* common_RotaryEmbedding.forward
common_RotaryEmbedding.forward#
transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward
--- original
+++ rewritten
@@ -1,18 +1,26 @@
-@torch.no_grad()
-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
-def forward(self, x, position_ids):
+@patched_dynamic_rope_update
+def forward(self, x, position_ids, layer_type=None):
+ if layer_type is not None:
+ # transformers>=5.0
+ inv_freq = getattr(self, f"{layer_type}_inv_freq")
+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+ else:
+ # transformers<5.0
+ inv_freq = self.inv_freq
+ attention_scaling = self.attention_scaling
+
inv_freq_expanded = (
- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
)
position_ids_expanded = position_ids[:, None, :].float()
device_type = (
x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
)
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos() * self.attention_scaling
- sin = emb.sin() * self.attention_scaling
+ cos = emb.cos() * attention_scaling
+ sin = emb.sin() * attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
--- original
+++ rewritten
@@ -1,103 +1,193 @@
-def dynamic_rope_update(rope_forward):
- """
- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+def patched_dynamic_rope_update(rope_forward):
+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
- Args:
- rope_forward (Callable):
- The forward pass of the RoPE implementation.
+ ``rope_type`` is determined in the constructor of class
+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
- Returns:
- The decorated forward pass.
+ .. code-block:: python
+
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get(
+ "rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+
+ The original code of the patched function:
+
+ .. code-block:: python
+
+ def dynamic_rope_update(rope_forward):
+ def longrope_frequency_update(self, position_ids, device):
+ seq_len = torch.max(position_ids) + 1
+ if hasattr(self.config, "original_max_position_embeddings"):
+ original_max_position_embeddings =
+ self.config.original_max_position_embeddings
+ else:
+ original_max_position_embeddings =
+ self.config.max_position_embeddings
+ if seq_len > original_max_position_embeddings:
+ if not hasattr(self, "long_inv_freq"):
+ self.long_inv_freq, _ = self.rope_init_fn(
+ self.config, device, seq_len=original_max_position_embeddings + 1
+ )
+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+ else:
+ self.original_inv_freq = self.original_inv_freq.to(device)
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+ def dynamic_frequency_update(self, position_ids, device):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and
+ self.max_seq_len_cached > self.original_max_seq_len:
+ self.original_inv_freq = self.original_inv_freq.to(device)
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @wraps(rope_forward)
+ def wrapper(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ dynamic_frequency_update(self, position_ids, device=x.device)
+ elif self.rope_type == "longrope":
+ longrope_frequency_update(self, position_ids, device=x.device)
+ return rope_forward(self, x, position_ids)
+
+ return wrapper
+
"""
def longrope_frequency_update(self, position_ids, device, layer_type=None):
- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+ # It is no use to patch the function after the model is created
+ # as rope_init_fn is an attribute set to one function when the model
+ # is created and when no patch is applied yet.
+ # So we select the patched version here.
+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
seq_len = torch.max(position_ids) + 1
+ if hasattr(self.config, "original_max_position_embeddings"):
+ original_max_position_embeddings = self.config.original_max_position_embeddings
+ else:
+ original_max_position_embeddings = self.config.max_position_embeddings
if layer_type is None:
- rope_type = self.rope_type
- original_inv_freq = self.original_inv_freq
- prefix = ""
- original_max_position_embeddings = self.config.rope_parameters[
- "original_max_position_embeddings"
- ]
- else:
- rope_type = self.rope_type[layer_type]
- original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
- prefix = f"{layer_type}_"
- original_max_position_embeddings = self.config.rope_parameters[layer_type][
- "original_max_position_embeddings"
- ]
-
- if seq_len > original_max_position_embeddings:
- if not hasattr(self, f"{layer_type}_long_inv_freq"):
- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
- long_inv_freq, _ = rope_init_fn(
- self.config,
- device,
- seq_len=original_max_position_embeddings + 1,
- layer_type=layer_type,
- )
- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
- setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
- else:
- # This .to() is needed if the model has been moved to a device after being initialized (because
- # the buffer is automatically moved, but not the original copy)
- original_inv_freq = original_inv_freq.to(device)
- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-
- def dynamic_frequency_update(self, position_ids, device, layer_type=None):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if layer_type is None:
- rope_type = self.rope_type
- max_seq_len_cached = self.max_seq_len_cached
+ # rope_type = self.rope_type
original_inv_freq = self.original_inv_freq
prefix = ""
else:
- rope_type = self.rope_type[layer_type]
- max_seq_len_cached = getattr(
- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
- )
+ # rope_type = self.rope_type[layer_type]
original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
prefix = f"{layer_type}_"
- if seq_len > max_seq_len_cached: # growth
- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
- inv_freq, self.attention_scaling = rope_init_fn(
- self.config,
- device,
- seq_len=seq_len,
- layer_type=layer_type,
- )
- # TODO joao: may break with compilation
- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
+ # At export time, seq_len is unknown.
+ long_inv_freq, _ = rope_init_fn(
+ self.config, device, seq_len=original_max_position_embeddings + 1
+ )
+ original_inv_freq = self.original_inv_freq.to(device)
- if (
- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
- ): # reset
- # This .to() is needed if the model has been moved to a device after being initialized (because
- # the buffer is automatically moved, but not the original copy)
- original_inv_freq = original_inv_freq.to(device)
- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
- setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
+ # PATCHED: uses torch.cond instead of a test
+ cond = (seq_len > original_max_position_embeddings).item()
+ inv_freq = torch.cond(
+ cond,
+ (lambda x, y: x.clone()),
+ (lambda x, y: y.clone()),
+ [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+ )
+ setattr(self, f"{prefix}inv_freq", inv_freq)
+ # if seq_len > original_max_position_embeddings:
+ # self.inv_freq = self.long_inv_freq
+ # else:
+ # self.inv_freq = self.original_inv_freq
+
+ def dynamic_frequency_update(self, position_ids, device, layer_type=None):
+ # constructor:
+ # - self.max_seq_len_cached = config.max_position_embeddings
+ # - self.original_max_seq_len = config.max_position_embeddings
+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+ # It is no use to patch the function after the model is created
+ # as rope_init_fn is an attribute set to one function when the model
+ # is created and when no patch is applied yet.
+ # So we select the patched version here.
+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
+
+ # This behaviour is difficult to translate.
+ # The sequence always grows.
+ # The test should always True.
+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+ #
+ # if seq_len > self.max_seq_len_cached: # growth
+ # inv_freq, self.attention_scaling = self.rope_init_fn(
+ # self.config, device, seq_len=seq_len
+ # )
+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
+ # self.max_seq_len_cached = seq_len
+ #
+ # So we should not need what follows.
+ #
+ # cond = (seq_len > self.max_seq_len_cached).item()
+ # self.attention_scaling = torch.cond(
+ # cond,
+ # (lambda x, y: x.clone()),
+ # (lambda x, y: y.clone()),
+ # [attention_scaling, self.attention_scaling],
+ # )
+
+ seq_len = torch.max(position_ids) + 1
+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
+
+ if layer_type is None:
+ # rope_type = self.rope_type
+ # max_seq_len_cached = self.max_seq_len_cached
+ original_inv_freq = self.original_inv_freq
+ prefix = ""
+ else:
+ # rope_type = self.rope_type[layer_type]
+ # max_seq_len_cached = getattr(
+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+ # )
+ original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+ prefix = f"{layer_type}_"
+
+ # Second test to translate.
+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+ # But in that case the following condition is a way to restore the original cache.
+
+ # if (
+ # seq_len < self.original_max_seq_len
+ # and self.max_seq_len_cached > self.original_max_seq_len
+ # ):
+ # self.original_inv_freq = self.original_inv_freq.to(device)
+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ # self.max_seq_len_cached = self.original_max_seq_len
+
+ original_inv_freq = self.original_inv_freq.to(device)
+ cond = (seq_len >= self.original_max_seq_len).item()
+ # PATCHED: uses torch.cond instead of a test
+ inv_freq = torch.cond(
+ cond,
+ (lambda x, y: x.clone()),
+ (lambda x, y: y.clone()),
+ [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
+ )
+ setattr(self, f"{prefix}inv_freq", inv_freq)
@wraps(rope_forward)
def wrapper(self, x, position_ids, layer_type=None):
- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
- kwargs = {"layer_type": layer_type} if layer_type is not None else {}
- if "dynamic" in rope_type:
- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
- elif rope_type == "longrope":
- longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
- return rope_forward(self, x, position_ids, **kwargs)
+ if layer_type is None:
+ if "dynamic" in self.rope_type:
+ dynamic_frequency_update(self, position_ids, device=x.device)
+ elif self.rope_type == "longrope":
+ longrope_frequency_update(self, position_ids, device=x.device)
+ return rope_forward(self, x, position_ids)
+
+ if "dynamic" in self.rope_type:
+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+ elif self.rope_type == "longrope":
+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+ return rope_forward(self, x, position_ids, layer_type=layer_type)
return wrapper