onnx_diagnostic.torch_export_patches.patch_details

class onnx_diagnostic.torch_export_patches.patch_details.PatchDetails[source]

This class stores patching information. It helps understand which rewriting was applied to which method or function. The page Patches Diff contains the diffs for all the implemented patches.

<<<

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
details = PatchDetails()
with torch_export_patches(
    patch_transformers=True, patch_details=details, patch_torch=False
):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
patches = details.patches_involded_in_graph(ep.graph)
report = details.make_report(patches, format="rst")
print(report)

>>>

auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,8 +1,16 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24@@ -12,7 +20,7 @@
 25     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 26         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 27         emb = torch.cat((freqs, freqs), dim=-1)
 28-        cos = emb.cos() * self.attention_scaling
 29-        sin = emb.sin() * self.attention_scaling
 30+        cos = emb.cos() * attention_scaling
 31+        sin = emb.sin() * attention_scaling
 32
 33     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 34
 35--- original
 36+++ rewritten
 37@@ -1,99 +1,193 @@
 38-def dynamic_rope_update(rope_forward):
 39-    """
 40-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 41-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 42+def patched_dynamic_rope_update(rope_forward):
 43+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 44
 45-    Args:
 46-        rope_forward (Callable):
 47-            The forward pass of the RoPE implementation.
 48+    ``rope_type`` is determined in the constructor of class
 49+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 50
 51-    Returns:
 52-        The decorated forward pass.
 53+    .. code-block:: python
 54+
 55+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 56+            self.rope_type = config.rope_scaling.get(
 57+                "rope_type", config.rope_scaling.get("type"))
 58+        else:
 59+            self.rope_type = "default"
 60+
 61+    The original code of the patched function:
 62+
 63+    .. code-block:: python
 64+
 65+        def dynamic_rope_update(rope_forward):
 66+            def longrope_frequency_update(self, position_ids, device):
 67+                seq_len = torch.max(position_ids) + 1
 68+                if hasattr(self.config, "original_max_position_embeddings"):
 69+                    original_max_position_embeddings =
 70+                        self.config.original_max_position_embeddings
 71+                else:
 72+                    original_max_position_embeddings =
 73+                        self.config.max_position_embeddings
 74+                if seq_len > original_max_position_embeddings:
 75+                    if not hasattr(self, "long_inv_freq"):
 76+                        self.long_inv_freq, _ = self.rope_init_fn(
 77+                            self.config, device, seq_len=original_max_position_embeddings + 1
 78+                        )
 79+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 80+                else:
 81+                    self.original_inv_freq = self.original_inv_freq.to(device)
 82+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 83+
 84+            def dynamic_frequency_update(self, position_ids, device):
 85+                seq_len = torch.max(position_ids) + 1
 86+                if seq_len > self.max_seq_len_cached:  # growth
 87+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 88+                        self.config, device, seq_len=seq_len)
 89+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 90+                    self.max_seq_len_cached = seq_len
 91+
 92+                if seq_len < self.original_max_seq_len and
 93+                        self.max_seq_len_cached > self.original_max_seq_len:
 94+                    self.original_inv_freq = self.original_inv_freq.to(device)
 95+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 96+                    self.max_seq_len_cached = self.original_max_seq_len
 97+
 98+            @wraps(rope_forward)
 99+            def wrapper(self, x, position_ids):
100+                if "dynamic" in self.rope_type:
101+                    dynamic_frequency_update(self, position_ids, device=x.device)
102+                elif self.rope_type == "longrope":
103+                    longrope_frequency_update(self, position_ids, device=x.device)
104+                return rope_forward(self, x, position_ids)
105+
106+            return wrapper
107+
108     """
109
110     def longrope_frequency_update(self, position_ids, device, layer_type=None):
111-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
112+        # It is no use to patch the function after the model is created
113+        # as rope_init_fn is an attribute set to one function when the model
114+        # is created and when no patch is applied yet.
115+        # So we select the patched version here.
116+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
117         seq_len = torch.max(position_ids) + 1
118-        original_max_position_embeddings = getattr(
119-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
120-        )
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125+
126         if layer_type is None:
127-            rope_type = self.rope_type
128+            # rope_type = self.rope_type
129             original_inv_freq = self.original_inv_freq
130             prefix = ""
131         else:
132-            rope_type = self.rope_type[layer_type]
133+            # rope_type = self.rope_type[layer_type]
134             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
135             prefix = f"{layer_type}_"
136
137-        if seq_len > original_max_position_embeddings:
138-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
139-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
140-                long_inv_freq, _ = rope_init_fn(
141-                    self.config,
142-                    device,
143-                    seq_len=original_max_position_embeddings + 1,
144-                    layer_type=layer_type,
145-                )
146-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
147-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
148-        else:
149-            # This .to() is needed if the model has been moved to a device after being initialized (because
150-            # the buffer is automatically moved, but not the original copy)
151-            original_inv_freq = original_inv_freq.to(device)
152-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
153-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
154+        # At export time, seq_len is unknown.
155+        long_inv_freq, _ = rope_init_fn(
156+            self.config, device, seq_len=original_max_position_embeddings + 1
157+        )
158+        original_inv_freq = self.original_inv_freq.to(device)
159+
160+        # PATCHED: uses torch.cond instead of a test
161+        cond = (seq_len > original_max_position_embeddings).item()
162+        inv_freq = torch.cond(
163+            cond,
164+            (lambda x, y: x.clone()),
165+            (lambda x, y: y.clone()),
166+            [long_inv_freq, original_inv_freq],
167+        )
168+        setattr(self, f"{prefix}inv_freq", inv_freq)
169+        # if seq_len > original_max_position_embeddings:
170+        #    self.inv_freq = self.long_inv_freq
171+        # else:
172+        #    self.inv_freq = self.original_inv_freq
173
174     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
175-        """
176-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
177-        1 - growing beyond the cached sequence length (allow scaling)
178-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
179-        """
180+        # constructor:
181+        # - self.max_seq_len_cached = config.max_position_embeddings
182+        # - self.original_max_seq_len = config.max_position_embeddings
183+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
184+
185+        # It is no use to patch the function after the model is created
186+        # as rope_init_fn is an attribute set to one function when the model
187+        # is created and when no patch is applied yet.
188+        # So we select the patched version here.
189+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
190+
191+        # This behaviour is difficult to translate.
192+        # The sequence always grows.
193+        # The test should always True.
194+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
195+        #
196+        # if seq_len > self.max_seq_len_cached:  # growth
197+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
198+        #        self.config, device, seq_len=seq_len
199+        #    )
200+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
201+        #    self.max_seq_len_cached = seq_len
202+        #
203+        # So we should not need what follows.
204+        #
205+        # cond = (seq_len > self.max_seq_len_cached).item()
206+        # self.attention_scaling = torch.cond(
207+        #    cond,
208+        #    (lambda x, y: x.clone()),
209+        #    (lambda x, y: y.clone()),
210+        #    [attention_scaling, self.attention_scaling],
211+        # )
212+
213         seq_len = torch.max(position_ids) + 1
214+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
215+
216         if layer_type is None:
217-            rope_type = self.rope_type
218-            max_seq_len_cached = self.max_seq_len_cached
219+            # rope_type = self.rope_type
220+            # max_seq_len_cached = self.max_seq_len_cached
221             original_inv_freq = self.original_inv_freq
222             prefix = ""
223         else:
224-            rope_type = self.rope_type[layer_type]
225-            max_seq_len_cached = getattr(
226-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
227-            )
228+            # rope_type = self.rope_type[layer_type]
229+            # max_seq_len_cached = getattr(
230+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
231+            # )
232             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
233             prefix = f"{layer_type}_"
234
235-        if seq_len > max_seq_len_cached:  # growth
236-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
237-            inv_freq, self.attention_scaling = rope_init_fn(
238-                self.config,
239-                device,
240-                seq_len=seq_len,
241-                layer_type=layer_type,
242-            )
243-            # TODO joao: may break with compilation
244-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
245-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
246+        # Second test to translate.
247+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
248+        # But in that case the following condition is a way to restore the original cache.
249
250-        if (
251-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
252-        ):  # reset
253-            # This .to() is needed if the model has been moved to a device after being initialized (because
254-            # the buffer is automatically moved, but not the original copy)
255-            original_inv_freq = original_inv_freq.to(device)
256-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
257-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
258-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
259+        # if (
260+        #    seq_len < self.original_max_seq_len
261+        #    and self.max_seq_len_cached > self.original_max_seq_len
262+        # ):
263+        #    self.original_inv_freq = self.original_inv_freq.to(device)
264+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
265+        #    self.max_seq_len_cached = self.original_max_seq_len
266+
267+        original_inv_freq = self.original_inv_freq.to(device)
268+        cond = (seq_len >= self.original_max_seq_len).item()
269+        # PATCHED: uses torch.cond instead of a test
270+        inv_freq = torch.cond(
271+            cond,
272+            (lambda x, y: x.clone()),
273+            (lambda x, y: y.clone()),
274+            [long_inv_freq, original_inv_freq],
275+        )
276+        setattr(self, f"{prefix}inv_freq", inv_freq)
277
278     @wraps(rope_forward)
279     def wrapper(self, x, position_ids, layer_type=None):
280-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
281-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
282-        if "dynamic" in rope_type:
283-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
284-        elif rope_type == "longrope":
285-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
286-        return rope_forward(self, x, position_ids, **kwargs)
287+        if layer_type is None:
288+            if "dynamic" in self.rope_type:
289+                dynamic_frequency_update(self, position_ids, device=x.device)
290+            elif self.rope_type == "longrope":
291+                longrope_frequency_update(self, position_ids, device=x.device)
292+            return rope_forward(self, x, position_ids)
293+
294+        if "dynamic" in self.rope_type:
295+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
296+        elif self.rope_type == "longrope":
297+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
298+        return rope_forward(self, x, position_ids, layer_type=layer_type)
299
300     return wrapper

impacted nodes

aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0) -> unsqueeze
aten.unsqueeze.default(unsqueeze, 2) -> unsqueeze_1
aten._assert_tensor_metadata.default(unsqueeze_1) -> _assert_tensor_metadata_default_3
aten.to.dtype(unsqueeze_1, torch.float32) -> to_3
aten.expand.default(to_3, [sym_size_int_19, -1, 1]) -> expand_4
aten._assert_tensor_metadata.default(expand_4) -> _assert_tensor_metadata_default_4
aten.to.dtype_layout(expand_4) -> to_4
aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807) -> slice_1
aten.unsqueeze.default(slice_1, 1) -> unsqueeze_2
aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807) -> slice_2
aten._assert_tensor_metadata.default(slice_2) -> _assert_tensor_metadata_default_5
aten.to.dtype(slice_2, torch.float32) -> to_5
<built-in function getitem>(wrap_with_autocast, 0) -> mul
<built-in function getitem>(wrap_with_autocast, 1) -> mul_1
aten._assert_tensor_metadata.default(mul) -> _assert_tensor_metadata_default_8
aten.to.dtype(mul, torch.float32) -> to_8
aten._assert_tensor_metadata.default(mul_1) -> _assert_tensor_metadata_default_9
aten.to.dtype(mul_1, torch.float32) -> to_9

append(family: str, function_to_patch: str | Callable, patch: Callable) → PatchInfo[source]

Stores a patch.

Parameters:
  • family – a category, anything to classify the patch

  • function_to_patch – the function to patch

  • patch – the replacement (patched) function

Returns:

instance of PatchInfo
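
A minimal sketch of storing a patch by hand; the two functions below are hypothetical stand-ins, any pair of callables works:

from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails

def original(x):
    return x + 1

def rewritten(x):
    # hypothetical replacement standing in for a real patch
    return x + 1

details = PatchDetails()
info = details.append("manual", original, rewritten)  # returns a PatchInfo
print(details.n_patches)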

data() → List[Dict[str, Any]][source]

Returns the data as a list of rows suitable for building a dataframe.
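
For instance, the rows can be loaded into a pandas DataFrame (a sketch assuming pandas is installed and details is a populated PatchDetails such as the one built above):

import pandas

# one row per stored patch
df = pandas.DataFrame(details.data())
print(df)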

find(name: str) → PatchInfo | None[source]

Finds a patch by name.
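
A short sketch; it assumes name matches the name of the stored original function, and that find returns None when nothing matches:

# "original" is the function registered in the sketch above (an assumption)
info = details.find("original")
if info is None:
    print("no patch stored under that name")
else:
    print(info.format_diff(format="raw"))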

make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') → str[source]

Creates a report based on the involved patches.

Parameters:
  • patches – the involved patches, as pairs (PatchInfo, impacted nodes) returned by patches_involded_in_graph

  • format – 'raw' or 'rst'

Returns:

the report as a string

matching_pair(patch: PatchInfo, node: torch.fx.Node) → bool[source]

Last validation for a pair. RotaryEmbedding has many rewritings and they all end up on the same code line.

property n_patches: int

Returns the number of stored patches.

patches_involded_in_graph(graph: torch.fx.Graph) → List[Tuple[PatchInfo, List[torch.fx.Node]]][source]

Enumerates all patches impacting a graph. The function goes through the graph nodes (only the main graph) and looks at their metadata to determine whether a listed patch was involved.

Parameters:

graph – fx graph

Returns:

list of pairs (patch, nodes impacted by that patch)
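
Continuing the first example on this page, a sketch of consuming the returned pairs:

patches = details.patches_involded_in_graph(ep.graph)
for patch_info, nodes in patches:
    # each entry pairs a PatchInfo with the fx nodes it impacted
    print(patch_info.to_dict(), len(nodes))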

class onnx_diagnostic.torch_export_patches.patch_details.PatchInfo(function_to_patch: str | Callable, patch: Callable, family: str = '')[source]

Stores information about patches.

Parameters:
  • function_to_patch – the function to patch

  • patch – the replacement (patched) function

  • family – a category, anything to classify the patch

format_diff(format: str = 'raw') → str[source]

Formats the diff between the two functions as a string.

Parameters:

format – 'raw' or 'rst'

Returns:

diff

<<<

import transformers
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

eager_mask = transformers.masking_utils.eager_mask
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
print(diff)

>>>

eager_mask -> patched_eager_mask

 1--- original
 2+++ rewritten
 3@@ -1,4 +1,4 @@
 4-def eager_mask(
 5+def patched_eager_mask(
 6     batch_size: int,
 7     cache_position: torch.Tensor,
 8     kv_length: int,
 9@@ -8,31 +8,12 @@
10     dtype: torch.dtype = torch.float32,
11     **kwargs,
12 ) -> torch.Tensor:
13-    """
14-    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
15-    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
16-    it should not.
17-
18-    Args:
19-        batch_size (`int`):
20-            The batch size of the input sequence.
21-        cache_position (`torch.Tensor`):
22-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
23-        kv_length (`int`):
24-            The size that the key and value states will have during the attention computation.
25-        kv_offset (`int`, optional):
26-            An optional offset to indicate at which first position the key and values states will refer to.
27-        mask_function (`Callable`):
28-            The mask factory function describing the mask pattern.
29-        attention_mask (`torch.Tensor`, optional):
30-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
31-        dtype (`torch.dtype`, optional):
32-            The dtype to use for the mask. By default, `torch.float32`.
33-    """
34+    """manual patch for function ``transformers.masking_utils.eager_mask``."""
35     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
36     _ = kwargs.pop("allow_is_causal_skip", None)
37     _ = kwargs.pop("allow_is_bidirectional_skip", None)
38-    mask = sdpa_mask(
39+    # PATCHED: this line called the patched version of sdpa_mask
40+    mask = patched_sdpa_mask_recent_torch(
41         batch_size=batch_size,
42         cache_position=cache_position,
43         kv_length=kv_length,
44@@ -45,6 +26,10 @@
45         **kwargs,
46     )
47     min_dtype = torch.finfo(dtype).min
48-    # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
49-    mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
50+    # PATCHED: the following line
51+    # we need 0s where the tokens should be taken into account,
52+    # and -inf otherwise (mask is already of boolean type)
53+    # mask =
54+    #   torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
55+    mask = (~mask).to(dtype) * min_dtype
56     return mask

make_diff() → str[source]

Returns a diff as a string.
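
A short sketch reusing the two functions from the previous example:

# diff between eager_mask and its patched version, as a plain string
info = PatchInfo(eager_mask, patched_eager_mask)
print(info.make_diff())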

to_dict() → Dict[str, Any][source]

Usual conversion to a dictionary.

to_tuple() → Tuple[str, Callable, Callable][source]

Usual conversion to a tuple.

onnx_diagnostic.torch_export_patches.patch_details.clean_code_with_black(code: str) → str[source]

Reformats the code with black if it is available.
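
A minimal sketch; when black is not installed, the code is expected to come back unchanged:

from onnx_diagnostic.torch_export_patches.patch_details import clean_code_with_black

code = "def f( x ):\n    return x+1\n"
print(clean_code_with_black(code))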

onnx_diagnostic.torch_export_patches.patch_details.make_diff_code(code1: str, code2: str, output: str | None = None) → str[source]

Creates a diff between two pieces of code.

Parameters:
  • code1 – first code

  • code2 – second code

  • output – if not empty, stores the output in this file

Returns:

diff
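
A minimal sketch diffing two small snippets; output is left unset so the diff is only returned:

from onnx_diagnostic.torch_export_patches.patch_details import make_diff_code

code1 = "def f(x):\n    return x + 1\n"
code2 = "def f(x):\n    return x + 2\n"
print(make_diff_code(code1, code2))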