.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_torch/plot_patch_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_torch_plot_patch_model.py:


.. _l-plot-patch-model-diff:

Applying patches to a model and displaying the diff
=====================================================

Before exporting a PyTorch model with :func:`torch.export.export`, a set of
**patches** must be applied to work around limitations in the PyTorch
exporter. This example shows how to:

1. Apply those patches with :func:`apply_patches_for_model `.
2. Inspect the :class:`PatchDetails ` object yielded by the context
   manager, which lists every applied patch.
3. Display a unified diff for each :class:`PatchInfo ` to see exactly what
   changed in the original PyTorch internals.
4. Render a diff as a matplotlib figure so that sphinx-gallery can capture
   it as an image.
5. Show which patches are actually exercised when exporting a real model
   (``arnir0/Tiny-LLM``).

The context manager both **applies** the patches on entry and **removes**
them on exit, so the original functions are restored once the ``with``
block ends.

.. GENERATED FROM PYTHON SOURCE LINES 27-34

.. code-block:: Python


    import torch

    from yobx import doc
    from yobx.helpers.patch_helper import PatchDetails
    from yobx.torch import apply_patches_for_model, register_flattening_functions, use_dyn_not_str
    from yobx.torch.tiny_models import get_tiny_model

.. GENERATED FROM PYTHON SOURCE LINES 35-46

1. Apply patches and inspect PatchDetails
------------------------------------------

This example uses two boolean flags of :func:`apply_patches_for_model`:

* ``patch_torch=True``: patches several internal PyTorch functions that
  otherwise prevent a successful export with dynamic shapes.
* ``patch_transformers=True``: adds further patches for 🤗 Transformers
  models.

The context manager yields a :class:`PatchDetails` instance that lists every
:class:`PatchInfo` that was applied.

.. GENERATED FROM PYTHON SOURCE LINES 46-53

.. code-block:: Python


    with apply_patches_for_model(patch_torch=True) as details:
        assert isinstance(details, PatchDetails)
        print(f"Number of patches applied: {details.n_patches}")
        for patch in details:
            print(f"  [{patch.family}] {patch.name}")



.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Number of patches applied: 5
      [torch] _print_Symbol
      [torch] patched_infer_size
      [torch] patched__broadcast_shapes
      [torch] patched__get_range_constraints
      [torch] patched__maybe_broadcast



.. GENERATED FROM PYTHON SOURCE LINES 54-64

2. Display the diff for each patch
------------------------------------

After the ``with`` block, the patches have been removed, but
:meth:`PatchInfo.format_diff` still works because the reference to the
original function is retained internally.

Each diff is a standard unified diff: lines starting with ``-`` come from
the original function and lines starting with ``+`` from the patched
version.

.. GENERATED FROM PYTHON SOURCE LINES 64-69

.. code-block:: Python


    for patch in details:
        print(patch.format_diff(format="raw"))
        print()



.. rst-class:: sphx-glr-script-out

 ..
code-block:: none torch: DynamicDimConstraintPrinter._print_Symbol -> patched_DynamicDimConstraintPrinter._print_Symbol --- original +++ rewritten @@ -1,6 +1,5 @@ def _print_Symbol(self, expr: sympy.Symbol) -> str: - if not isinstance(expr, sympy.Symbol): - raise AssertionError(f"Expected sympy.Symbol, got {type(expr)}") - if not self.symbol_to_source.get(expr): - raise AssertionError(f"Unknown symbol {expr} created by constraints solver") - return self.symbol_to_source[expr][0].name + assert isinstance(expr, sympy.Symbol), str(type(expr)) + if self.symbol_to_source.get(expr): # type: ignore + return self.symbol_to_source[expr][0].name # type: ignore + return str(expr) torch: infer_size -> patched_infer_size --- original +++ rewritten @@ -1,10 +1,13 @@ -def infer_size(a: Sequence[IntLikeType], b: Sequence[IntLikeType]) -> tuple[IntLikeType, ...]: - from torch.fx.experimental.symbolic_shapes import guard_or_false - +def patched_infer_size(a, b): + """ + Patches ``torch._subclasses.fake_impls.infer_size``. + This patch is needed to export + :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`. + """ dimsA = len(a) dimsB = len(b) ndim = max(dimsA, dimsB) - expandedSizes: list[IntLikeType] = [0] * ndim + expandedSizes = [0] * ndim for i in range(ndim - 1, -1, -1): offset = ndim - 1 - i dimA = dimsA - 1 - offset @@ -23,11 +26,21 @@ # expression of an or statement as-is, without bool()'ing it; if this # were not the case, we'd need to write this using torch.sym_or() or # something like that). 
- torch._check( - guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB, - lambda: f"The size of tensor a ({sizeA}) " - f"must match the size of tensor b ({sizeB}) " - f"at non-singleton dimension {i})", - ) - expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA + try: + b1 = fx_symbolic_shapes.guard_or_false(sizeA == 1) + except fx_symbolic_shapes.GuardOnDataDependentSymNode: + b1 = False + try: + b2 = fx_symbolic_shapes.guard_or_false(sizeB == 1) + except fx_symbolic_shapes.GuardOnDataDependentSymNode: + b2 = False + try: + b3 = fx_symbolic_shapes.guard_or_false(sizeA == sizeB) + except fx_symbolic_shapes.GuardOnDataDependentSymNode: + b3 = False + if b1 or b2 or b3: + expandedSizes[i] = sizeB if fx_symbolic_shapes.guard_or_false(sizeA == 1) else sizeA + else: + # PATCHED: generic case, the dimension is known, no need to assert + expandedSizes[i] = torch.sym_max(sizeA, sizeB) # type: ignore return tuple(expandedSizes) torch: _broadcast_shapes -> patched__broadcast_shapes --- original +++ rewritten @@ -1,12 +1,11 @@ -def _broadcast_shapes(*_shapes): - from torch.fx.experimental.symbolic_shapes import ( - guard_or_false, - guarding_hint_or_throw, - has_guarding_hint, - is_nested_int, - ) - - backed_so = torch.fx.experimental._config.backed_size_oblivious +def patched__broadcast_shapes(*_shapes): + """ + Patches ``torch._refs._broadcast_shapes``. + This patch is needed to export + :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`. 
+ """ + from functools import reduce + from torch._prims_common import IntLike shapes = tuple( (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes) @@ -19,63 +18,36 @@ for shape in shapes: if not isinstance(shape, Sequence): raise RuntimeError( - "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + "Input shapes should be of type ints, a tuple of ints, " + "or a list of ints, got ", shape, ) # Computes common shape - common_shape: list[int | torch.SymInt] = [ - 1, - ] * reduce(max, (len(shape) for shape in shapes)) - for arg_idx, shape in enumerate(shapes): + common_shape = [1] * reduce(max, (len(shape) for shape in shapes)) + for _arg_idx, shape in enumerate(shapes): for idx in range(-1, -1 - len(shape), -1): - # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1). - if is_nested_int(shape[idx]): + if fx_symbolic_shapes.is_nested_int(shape[idx]): # Broadcasting is allowed for (j0, 1) or (j0, j0); # not (j0, j1), (j0, 5), etc. - if is_nested_int(common_shape[idx]) and guard_or_false( - shape[idx] == common_shape[idx] - ): + if fx_symbolic_shapes.is_nested_int( + common_shape[idx] + ) and fx_symbolic_shapes.guard_or_false(shape[idx] == common_shape[idx]): continue else: - # When backed size oblivious is used, we specialize for broadcasting - # if its the only way to compile the example input. - # i.e: s0:1, s1:1 ==> - # assert s0==s1, no specialization on ==1 or !=1. - # The non-broadcast path is picked - # s0:1, s1:4 ==> - # specialize(s0) to be 1. - # s0:4, s1:1 ==> - # specialize(s1) to be 1. 
- if ( - backed_so - and has_guarding_hint(shape[idx]) - and has_guarding_hint(common_shape[idx]) - ): - a = guarding_hint_or_throw(shape[idx]) - b = guarding_hint_or_throw(common_shape[idx]) - if a == 1 and b != 1: - torch._check(shape[idx] == 1) - if b == 1 and a != 1: - torch._check(common_shape[idx] == 1) - if guard_or_false(shape[idx] == common_shape[idx]): + if fx_symbolic_shapes.guard_or_false(shape[idx] == common_shape[idx]): continue - - if guard_or_false(common_shape[idx] == 1): + # PATCHED: two cases, if == for sure, no broadcast, + # otherwise maybe broadcast with max(dimensions) + if fx_symbolic_shapes.guard_or_false(common_shape[idx] != 1): + pass + elif fx_symbolic_shapes.guard_or_false( + common_shape[idx] == 1 + ) or fx_symbolic_shapes.guard_or_false(shape[idx] != 1): if shape[idx] < 0: raise ValueError("Attempting to broadcast a dimension with negative length!") - common_shape[idx] = shape[idx] - - if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1): - # broadcast case . - continue + common_shape[idx] = shape[idx] # type: ignore else: - # If broadcasting is undecided we pick non-broadcast path and add runtime assertion. - torch._check( - common_shape[idx] == shape[idx], - lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " - f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " - f"should be broadcastable to {common_shape}", - ) + common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx]) # type: ignore return common_shape torch: _get_range_constraints -> patched__get_range_constraints --- original +++ rewritten @@ -1,42 +1,36 @@ -def _get_range_constraints( +def patched__get_range_constraints( mod: torch.nn.Module, - export_artifact: ExportArtifact, + export_artifact: torch.export._trace.ExportArtifact, args, kwargs, dynamic_shapes, ): + """ + Patches ``torch.export._trace._get_range_constraints``. + See PR `#174593 `_. 
+ """ gm: torch.fx.GraphModule = export_artifact.aten.gm - export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + export_graph_signature: torch.export.graph_signature.ExportGraphSignature = ( + export_artifact.aten.sig + ) fake_mode: FakeTensorMode = export_artifact.fake_mode num_lifted = next( ( i for i, s in enumerate(export_graph_signature.input_specs) - if s.kind == InputKind.USER_INPUT + if s.kind == torch.export.graph_signature.InputKind.USER_INPUT ), len(export_graph_signature.input_specs), ) - combined_args = _combine_args(mod, args, kwargs) - # This is because we trace based on the kwargs passed in from user + # preserve_order=True: + # this is because we trace based on the kwargs passed in from user # not based on the signature. I feel it would be better to just enforce # one ordering at the start of tracing to avoid confusions, but that is # bigger refactor, so do this to unblock for now. - combined_args_traced_order = {} - for arg in combined_args: - if arg not in kwargs: - combined_args_traced_order[arg] = combined_args[arg] + combined_args = _combine_args(mod, args, kwargs, preserve_order=True) - for key in kwargs: - combined_args_traced_order[key] = kwargs[key] - - combined_args = combined_args_traced_order - - range_constraints = make_constraints( - fake_mode, - gm, - combined_args, - dynamic_shapes, - num_lifted, + range_constraints = torch._export.non_strict_utils.make_constraints( + fake_mode, gm, combined_args, dynamic_shapes, num_lifted ) return range_constraints torch: _maybe_broadcast -> patched__maybe_broadcast --- original +++ rewritten @@ -1,15 +1,18 @@ -def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): +def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): + """ + Patches ``torch._refs._maybe_broadcast``. + This patch is needed to export + :class:`yobx.torch.tiny_models.TinyBroadcastAddModel`. 
+ """ + from torch._prims_common import ShapeType, TensorLike, Number + # Computes common shape - common_shape = _broadcast_shapes( + common_shape = patched__broadcast_shapes( *(t.shape if isinstance(t, TensorLike) else None for t in args) ) def should_expand(a: ShapeType, b: ShapeType) -> bool: - from torch.fx.experimental.symbolic_shapes import ( - guard_or_false, - sym_and, - sym_or, - ) + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and, sym_or if len(a) != len(b): return True @@ -29,10 +32,15 @@ return True # u0==u1 assume the same, no broadcasting! - torch._check( - x == y, - lambda: "sizes assumed to be the same due to unbacked broadcasting semantics", - ) + # PATCHED: avoid errors + return True # guard_or_true(x != y) + # torch._check( + # x == y, + # lambda x=x, y=y: ( + # f"sizes assumed to be the same due to unbacked " + # f"broadcasting semantics x={x!r}, y={y!r}" + # ), + # ) return False @@ -42,14 +50,14 @@ elif isinstance(x, Number): return x elif isinstance(x, TensorLike): - if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x): + if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x): # type: ignore return x - if should_expand(x.shape, common_shape): - return x.expand(common_shape) + if should_expand(x.shape, common_shape): # type: ignore + return x.expand(common_shape) # type: ignore return x else: - raise RuntimeError("Unexpected type when broadcasting: " + str(type(x)) + "!") + raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!") return tuple(__maybe_broadcast(x, common_shape) for x in args)

.. GENERATED FROM PYTHON SOURCE LINES 70-78

3. Plot the diff text as an image
-----------------------------------

The first 10 lines of the shortest diff (the one with the fewest lines) are
rendered as a matplotlib figure with colour-coded lines: ``-`` lines in red,
``+`` lines in green, and ``@@`` hunk headers in blue. Rendering the diff as
a figure lets sphinx-gallery capture it as an image.
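
To make the colour-coding concrete, the sketch below builds a unified diff
with the standard-library ``difflib`` module and picks a colour for each line
from its first character. It is a minimal illustration, not the
implementation of :func:`yobx.doc.plot_text`: the helpers
``make_unified_diff``, ``line_colors`` and ``plot_colored_diff`` and the toy
``size`` sources are hypothetical, and only the colour values come from
``_DIFF_COLORS`` in this section. Matplotlib is imported lazily, so the diff
and colour helpers run with the standard library alone.

.. code-block:: Python

    import difflib

    # Colours keyed on the first character of a diff line (values from _DIFF_COLORS).
    DIFF_COLORS = {"+": "#2a9d2a", "-": "#cc2222", "@": "#1a6fbf"}
    DEFAULT_COLOR = "#000000"  # context lines start with a space and stay black

    # Hypothetical toy sources standing in for an original function and its patch.
    ORIGINAL_SRC = '''def size(a, b):
        assert a == b, "sizes must match"
        return a
    '''
    PATCHED_SRC = '''def size(a, b):
        # PATCHED: broadcast instead of failing
        return max(a, b)
    '''


    def make_unified_diff(old_src, new_src):
        """Hypothetical helper: unified diff of two source strings, as one string."""
        return "\n".join(
            difflib.unified_diff(
                old_src.splitlines(),
                new_src.splitlines(),
                fromfile="original",
                tofile="rewritten",
                lineterm="",  # splitlines() already removed the newlines
            )
        )


    def line_colors(lines, color_map=DIFF_COLORS, default=DEFAULT_COLOR):
        """Hypothetical helper: one colour per line, chosen by its first character."""
        return [color_map.get(line[:1], default) for line in lines]


    def plot_colored_diff(diff_text, title="", max_lines=10):
        """Hypothetical helper: draw the first lines of a diff as coloured monospace text."""
        import matplotlib.pyplot as plt  # lazy import: the helpers above need only the stdlib

        lines = diff_text.splitlines()[:max_lines]
        fig, ax = plt.subplots(figsize=(8, 0.3 * len(lines) + 0.6))
        ax.axis("off")
        for i, (line, color) in enumerate(zip(lines, line_colors(lines))):
            # Stack the lines from top to bottom in axes coordinates.
            y = 1.0 - (i + 0.5) / len(lines)
            ax.text(0.0, y, line, color=color, fontfamily="monospace",
                    va="center", transform=ax.transAxes)
        ax.set_title(title)
        return fig


    DIFF_TEXT = make_unified_diff(ORIGINAL_SRC, PATCHED_SRC)
    print(DIFF_TEXT)

Calling ``plot_colored_diff(DIFF_TEXT, title="size")`` draws the eight diff
lines in those colours. Because the lookup only inspects the first character,
the ``---`` and ``+++`` file headers are coloured like removed and added
lines.
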
:func:`yobx.doc.plot_text` automates this rendering; the colours are passed
through its ``line_color_map`` argument.

.. GENERATED FROM PYTHON SOURCE LINES 78-88

.. code-block:: Python


    import matplotlib.pyplot as plt  # noqa: E402

    _DIFF_COLORS = {"+": "#2a9d2a", "-": "#cc2222", "@": "#1a6fbf"}

    smallest = min(details, key=lambda p: len(p.make_diff().splitlines()))
    diff_preview = "\n".join(smallest.make_diff().splitlines()[:10])
    doc.plot_text(diff_preview, title=smallest.name, line_color_map=_DIFF_COLORS)
    plt.show()



.. image-sg:: /auto_examples_torch/images/sphx_glr_plot_patch_model_001.png
   :alt: _print_Symbol
   :srcset: /auto_examples_torch/images/sphx_glr_plot_patch_model_001.png
   :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 89-100

4. Show which patches apply when exporting arnir0/Tiny-LLM
-----------------------------------------------------------

When exporting a real transformers model, we can find out exactly which
patched functions were exercised by calling
:meth:`PatchDetails.patches_involved_in_graph` after
:func:`torch.export.export`.

:func:`register_flattening_functions` must also be active so that the
exporter understands the pytree structure of
:class:`~transformers.DynamicCache`.

:meth:`PatchDetails.make_report` then formats the involved patches: in the
output below, each one is shown with its diff, followed by a list of nodes
from the exported graph.

.. GENERATED FROM PYTHON SOURCE LINES 100-113

.. code-block:: Python


    data = get_tiny_model("arnir0/Tiny-LLM")
    model, inputs, ds = data.model, data.export_inputs, data.dynamic_shapes

    with (
        register_flattening_functions(patch_transformers=True),
        apply_patches_for_model(patch_torch=True, patch_transformers=True, model=model) as details2,
    ):
        ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds))
        patches = details2.patches_involved_in_graph(ep.graph)

    print(f"\nPatches involved in the exported graph: {len(patches)}")
    print(details2.make_report(patches))



.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    use_kernel_func_from_hub is not available in the installed kernels version. Please upgrade kernels to use this feature.
Patches involved in the exported graph: 1 transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward --- original +++ rewritten @@ -1,18 +1,26 @@ -@torch.no_grad() -@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) -def forward(self, x, position_ids): +@patched_dynamic_rope_update +def forward(self, x, position_ids, layer_type=None): + if layer_type is not None: + # transformers>=5.0 + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + else: + # transformers<5.0 + inv_freq = self.inv_freq + attention_scaling = self.attention_scaling + inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) ) position_ids_expanded = position_ids[:, None, :].float() device_type = ( x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" ) - with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) --- original +++ rewritten @@ -1,103 +1,193 @@ -def dynamic_rope_update(rope_forward): - """ - Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE - (i.e. a RoPE implementation that may recompute its frequencies in the forward pass). 
+def patched_dynamic_rope_update(rope_forward): + """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]`` - Args: - rope_forward (Callable): - The forward pass of the RoPE implementation. + ``rope_type`` is determined in the constructor of class + :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`. - Returns: - The decorated forward pass. + .. code-block:: python + + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + + The original code of the patched function: + + .. code-block:: python + + def dynamic_rope_update(rope_forward): + def longrope_frequency_update(self, position_ids, device): + seq_len = torch.max(position_ids) + 1 + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = + self.config.original_max_position_embeddings + else: + original_max_position_embeddings = + self.config.max_position_embeddings + if seq_len > original_max_position_embeddings: + if not hasattr(self, "long_inv_freq"): + self.long_inv_freq, _ = self.rope_init_fn( + self.config, device, seq_len=original_max_position_embeddings + 1 + ) + self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) + else: + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + + def dynamic_frequency_update(self, position_ids, device): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and + self.max_seq_len_cached > self.original_max_seq_len: + self.original_inv_freq = self.original_inv_freq.to(device) + 
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @wraps(rope_forward) + def wrapper(self, x, position_ids): + if "dynamic" in self.rope_type: + dynamic_frequency_update(self, position_ids, device=x.device) + elif self.rope_type == "longrope": + longrope_frequency_update(self, position_ids, device=x.device) + return rope_forward(self, x, position_ids) + + return wrapper + """ def longrope_frequency_update(self, position_ids, device, layer_type=None): - """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" + # It is no use to patch the function after the model is created + # as rope_init_fn is an attribute set to one function when the model + # is created and when no patch is applied yet. + # So we select the patched version here. + rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type) seq_len = torch.max(position_ids) + 1 + if hasattr(self.config, "original_max_position_embeddings"): + original_max_position_embeddings = self.config.original_max_position_embeddings + else: + original_max_position_embeddings = self.config.max_position_embeddings if layer_type is None: - rope_type = self.rope_type - original_inv_freq = self.original_inv_freq - prefix = "" - original_max_position_embeddings = self.config.rope_parameters[ - "original_max_position_embeddings" - ] - else: - rope_type = self.rope_type[layer_type] - original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") - prefix = f"{layer_type}_" - original_max_position_embeddings = self.config.rope_parameters[layer_type][ - "original_max_position_embeddings" - ] - - if seq_len > original_max_position_embeddings: - if not hasattr(self, f"{layer_type}_long_inv_freq"): - rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] - long_inv_freq, _ = rope_init_fn( - self.config, - device, - seq_len=original_max_position_embeddings + 1, - layer_type=layer_type, - ) - 
self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False) - setattr(self, f"{prefix}long_inv_freq", long_inv_freq) - else: - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - original_inv_freq = original_inv_freq.to(device) - self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) - setattr(self, f"{prefix}original_inv_freq", original_inv_freq) - - def dynamic_frequency_update(self, position_ids, device, layer_type=None): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if layer_type is None: - rope_type = self.rope_type - max_seq_len_cached = self.max_seq_len_cached + # rope_type = self.rope_type original_inv_freq = self.original_inv_freq prefix = "" else: - rope_type = self.rope_type[layer_type] - max_seq_len_cached = getattr( - self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached - ) + # rope_type = self.rope_type[layer_type] original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") prefix = f"{layer_type}_" - if seq_len > max_seq_len_cached: # growth - rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] - inv_freq, self.attention_scaling = rope_init_fn( - self.config, - device, - seq_len=seq_len, - layer_type=layer_type, - ) - # TODO joao: may break with compilation - self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False) - setattr(self, f"{layer_type}_max_seq_len_cached", seq_len) + # At export time, seq_len is unknown. 
+ long_inv_freq, _ = rope_init_fn( + self.config, device, seq_len=original_max_position_embeddings + 1 + ) + original_inv_freq = self.original_inv_freq.to(device) - if ( - seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len - ): # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - original_inv_freq = original_inv_freq.to(device) - self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) - setattr(self, f"{prefix}original_inv_freq", original_inv_freq) - setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len) + # PATCHED: uses torch.cond instead of a test + cond = (seq_len > original_max_position_embeddings).item() + inv_freq = torch.cond( + cond, + (lambda x, y: x.clone()), + (lambda x, y: y.clone()), + [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq], + ) + setattr(self, f"{prefix}inv_freq", inv_freq) + # if seq_len > original_max_position_embeddings: + # self.inv_freq = self.long_inv_freq + # else: + # self.inv_freq = self.original_inv_freq + + def dynamic_frequency_update(self, position_ids, device, layer_type=None): + # constructor: + # - self.max_seq_len_cached = config.max_position_embeddings + # - self.original_max_seq_len = config.max_position_embeddings + # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + + # It is no use to patch the function after the model is created + # as rope_init_fn is an attribute set to one function when the model + # is created and when no patch is applied yet. + # So we select the patched version here. + rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type) + + # This behaviour is difficult to translate. + # The sequence always grows. + # The test should always True. 
+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len + # + # if seq_len > self.max_seq_len_cached: # growth + # inv_freq, self.attention_scaling = self.rope_init_fn( + # self.config, device, seq_len=seq_len + # ) + # self.register_buffer("inv_freq", inv_freq, persistent=False) + # self.max_seq_len_cached = seq_len + # + # So we should not need what follows. + # + # cond = (seq_len > self.max_seq_len_cached).item() + # self.attention_scaling = torch.cond( + # cond, + # (lambda x, y: x.clone()), + # (lambda x, y: y.clone()), + # [attention_scaling, self.attention_scaling], + # ) + + seq_len = torch.max(position_ids) + 1 + long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len) + + if layer_type is None: + # rope_type = self.rope_type + # max_seq_len_cached = self.max_seq_len_cached + original_inv_freq = self.original_inv_freq + prefix = "" + else: + # rope_type = self.rope_type[layer_type] + # max_seq_len_cached = getattr( + # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached + # ) + original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") + prefix = f"{layer_type}_" + + # Second test to translate. + # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True. + # But in that case the following condition is a way to restore the original cache. 
+ + # if ( + # seq_len < self.original_max_seq_len + # and self.max_seq_len_cached > self.original_max_seq_len + # ): + # self.original_inv_freq = self.original_inv_freq.to(device) + # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + # self.max_seq_len_cached = self.original_max_seq_len + + original_inv_freq = self.original_inv_freq.to(device) + cond = (seq_len >= self.original_max_seq_len).item() + # PATCHED: uses torch.cond instead of a test + inv_freq = torch.cond( + cond, + (lambda x, y: x.clone()), + (lambda x, y: y.clone()), + [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq], + ) + setattr(self, f"{prefix}inv_freq", inv_freq) @wraps(rope_forward) def wrapper(self, x, position_ids, layer_type=None): - rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type] - kwargs = {"layer_type": layer_type} if layer_type is not None else {} - if "dynamic" in rope_type: - dynamic_frequency_update(self, position_ids, device=x.device, **kwargs) - elif rope_type == "longrope": - longrope_frequency_update(self, position_ids, device=x.device, **kwargs) - return rope_forward(self, x, position_ids, **kwargs) + if layer_type is None: + if "dynamic" in self.rope_type: + dynamic_frequency_update(self, position_ids, device=x.device) + elif self.rope_type == "longrope": + longrope_frequency_update(self, position_ids, device=x.device) + return rope_forward(self, x, position_ids) + + if "dynamic" in self.rope_type: + dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type) + elif self.rope_type == "longrope": + longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type) + return rope_forward(self, x, position_ids, layer_type=layer_type) return wrapper aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0) -> unsqueeze_12 aten.unsqueeze.default(unsqueeze_12, 2) -> unsqueeze_13 aten._assert_tensor_metadata.default(unsqueeze_13) -> _assert_tensor_metadata_default_3 
    aten.to.dtype(unsqueeze_13, torch.float32) -> to_3
    aten.expand.default(to_3, [sym_size_int_18, -1, 1]) -> expand_1
    aten._assert_tensor_metadata.default(expand_1) -> _assert_tensor_metadata_default_4
    aten.to.dtype_layout(expand_1) -> to_4
    aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807) -> slice_4
    aten.unsqueeze.default(slice_4, 1) -> unsqueeze_14
    aten.slice.Tensor(unsqueeze_14, 2, 0, 9223372036854775807) -> slice_5
    aten._assert_tensor_metadata.default(slice_5) -> _assert_tensor_metadata_default_5
    aten.to.dtype(slice_5, torch.float32) -> to_5
    (wrap_with_autocast, 0) -> mul
    (wrap_with_autocast, 1) -> mul_1
    aten._assert_tensor_metadata.default(mul) -> _assert_tensor_metadata_default_8
    aten.to.dtype(mul, torch.float32) -> to_8
    aten._assert_tensor_metadata.default(mul_1) -> _assert_tensor_metadata_default_9
    aten.to.dtype(mul_1, torch.float32) -> to_9



.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 5.601 seconds)


.. _sphx_glr_download_auto_examples_torch_plot_patch_model.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_patch_model.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_patch_model.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_patch_model.zip `


.. include:: plot_patch_model.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_