onnx_diagnostic.torch_export_patches.patch_details

class onnx_diagnostic.torch_export_patches.patch_details.PatchDetails[source]

This class is used to store patching information. It helps understand which rewriting was applied to which method or function. The page Patches Diff contains the diffs for all the implemented patches.

<<<

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
details = PatchDetails()
with torch_export_patches(
    patch_transformers=True, patch_details=details, patch_torch=False
):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
patches = details.patches_involved_in_graph(ep.graph)
report = details.make_report(patches, format="rst")
print(report)

>>>

auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,18 +1,26 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24     device_type = (
 25         x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
 26     )
 27-    with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
 28+    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 29         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 30         emb = torch.cat((freqs, freqs), dim=-1)
 31-        cos = emb.cos() * self.attention_scaling
 32-        sin = emb.sin() * self.attention_scaling
 33+        cos = emb.cos() * attention_scaling
 34+        sin = emb.sin() * attention_scaling
 35
 36     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 37
 38--- original
 39+++ rewritten
 40@@ -1,103 +1,193 @@
 41-def dynamic_rope_update(rope_forward):
 42-    """
 43-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 44-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 45+def patched_dynamic_rope_update(rope_forward):
 46+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 47
 48-    Args:
 49-        rope_forward (Callable):
 50-            The forward pass of the RoPE implementation.
 51+    ``rope_type`` is determined in the constructor of class
 52+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 53
 54-    Returns:
 55-        The decorated forward pass.
 56+    .. code-block:: python
 57+
 58+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 59+            self.rope_type = config.rope_scaling.get(
 60+                "rope_type", config.rope_scaling.get("type"))
 61+        else:
 62+            self.rope_type = "default"
 63+
 64+    The original code of the patched function:
 65+
 66+    .. code-block:: python
 67+
 68+        def dynamic_rope_update(rope_forward):
 69+            def longrope_frequency_update(self, position_ids, device):
 70+                seq_len = torch.max(position_ids) + 1
 71+                if hasattr(self.config, "original_max_position_embeddings"):
 72+                    original_max_position_embeddings =
 73+                        self.config.original_max_position_embeddings
 74+                else:
 75+                    original_max_position_embeddings =
 76+                        self.config.max_position_embeddings
 77+                if seq_len > original_max_position_embeddings:
 78+                    if not hasattr(self, "long_inv_freq"):
 79+                        self.long_inv_freq, _ = self.rope_init_fn(
 80+                            self.config, device, seq_len=original_max_position_embeddings + 1
 81+                        )
 82+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 83+                else:
 84+                    self.original_inv_freq = self.original_inv_freq.to(device)
 85+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 86+
 87+            def dynamic_frequency_update(self, position_ids, device):
 88+                seq_len = torch.max(position_ids) + 1
 89+                if seq_len > self.max_seq_len_cached:  # growth
 90+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 91+                        self.config, device, seq_len=seq_len)
 92+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 93+                    self.max_seq_len_cached = seq_len
 94+
 95+                if seq_len < self.original_max_seq_len and
 96+                        self.max_seq_len_cached > self.original_max_seq_len:
 97+                    self.original_inv_freq = self.original_inv_freq.to(device)
 98+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 99+                    self.max_seq_len_cached = self.original_max_seq_len
100+
101+            @wraps(rope_forward)
102+            def wrapper(self, x, position_ids):
103+                if "dynamic" in self.rope_type:
104+                    dynamic_frequency_update(self, position_ids, device=x.device)
105+                elif self.rope_type == "longrope":
106+                    longrope_frequency_update(self, position_ids, device=x.device)
107+                return rope_forward(self, x, position_ids)
108+
109+            return wrapper
110+
111     """
112
113     def longrope_frequency_update(self, position_ids, device, layer_type=None):
114-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+        # It is no use to patch the function after the model is created
116+        # as rope_init_fn is an attribute set to one function when the model
117+        # is created and when no patch is applied yet.
118+        # So we select the patched version here.
119+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120         seq_len = torch.max(position_ids) + 1
121+        if hasattr(self.config, "original_max_position_embeddings"):
122+            original_max_position_embeddings = self.config.original_max_position_embeddings
123+        else:
124+            original_max_position_embeddings = self.config.max_position_embeddings
125
126         if layer_type is None:
127-            rope_type = self.rope_type
128-            original_inv_freq = self.original_inv_freq
129-            prefix = ""
130-            original_max_position_embeddings = self.config.rope_parameters[
131-                "original_max_position_embeddings"
132-            ]
133-        else:
134-            rope_type = self.rope_type[layer_type]
135-            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
136-            prefix = f"{layer_type}_"
137-            original_max_position_embeddings = self.config.rope_parameters[layer_type][
138-                "original_max_position_embeddings"
139-            ]
140-
141-        if seq_len > original_max_position_embeddings:
142-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
143-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
144-                long_inv_freq, _ = rope_init_fn(
145-                    self.config,
146-                    device,
147-                    seq_len=original_max_position_embeddings + 1,
148-                    layer_type=layer_type,
149-                )
150-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
151-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
152-        else:
153-            # This .to() is needed if the model has been moved to a device after being initialized (because
154-            # the buffer is automatically moved, but not the original copy)
155-            original_inv_freq = original_inv_freq.to(device)
156-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
157-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
158-
159-    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
160-        """
161-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
162-        1 - growing beyond the cached sequence length (allow scaling)
163-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
164-        """
165-        seq_len = torch.max(position_ids) + 1
166-        if layer_type is None:
167-            rope_type = self.rope_type
168-            max_seq_len_cached = self.max_seq_len_cached
169+            # rope_type = self.rope_type
170             original_inv_freq = self.original_inv_freq
171             prefix = ""
172         else:
173-            rope_type = self.rope_type[layer_type]
174-            max_seq_len_cached = getattr(
175-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
176-            )
177+            # rope_type = self.rope_type[layer_type]
178             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
179             prefix = f"{layer_type}_"
180
181-        if seq_len > max_seq_len_cached:  # growth
182-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
183-            inv_freq, self.attention_scaling = rope_init_fn(
184-                self.config,
185-                device,
186-                seq_len=seq_len,
187-                layer_type=layer_type,
188-            )
189-            # TODO joao: may break with compilation
190-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
191-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
192+        # At export time, seq_len is unknown.
193+        long_inv_freq, _ = rope_init_fn(
194+            self.config, device, seq_len=original_max_position_embeddings + 1
195+        )
196+        original_inv_freq = self.original_inv_freq.to(device)
197
198-        if (
199-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
200-        ):  # reset
201-            # This .to() is needed if the model has been moved to a device after being initialized (because
202-            # the buffer is automatically moved, but not the original copy)
203-            original_inv_freq = original_inv_freq.to(device)
204-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
205-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
206-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
207+        # PATCHED: uses torch.cond instead of a test
208+        cond = (seq_len > original_max_position_embeddings).item()
209+        inv_freq = torch.cond(
210+            cond,
211+            (lambda x, y: x.clone()),
212+            (lambda x, y: y.clone()),
213+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
214+        )
215+        setattr(self, f"{prefix}inv_freq", inv_freq)
216+        # if seq_len > original_max_position_embeddings:
217+        #    self.inv_freq = self.long_inv_freq
218+        # else:
219+        #    self.inv_freq = self.original_inv_freq
220+
221+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
222+        # constructor:
223+        # - self.max_seq_len_cached = config.max_position_embeddings
224+        # - self.original_max_seq_len = config.max_position_embeddings
225+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
226+
227+        # It is no use to patch the function after the model is created
228+        # as rope_init_fn is an attribute set to one function when the model
229+        # is created and when no patch is applied yet.
230+        # So we select the patched version here.
231+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
232+
233+        # This behaviour is difficult to translate.
234+        # The sequence always grows.
235+        # The test should always True.
236+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
237+        #
238+        # if seq_len > self.max_seq_len_cached:  # growth
239+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
240+        #        self.config, device, seq_len=seq_len
241+        #    )
242+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
243+        #    self.max_seq_len_cached = seq_len
244+        #
245+        # So we should not need what follows.
246+        #
247+        # cond = (seq_len > self.max_seq_len_cached).item()
248+        # self.attention_scaling = torch.cond(
249+        #    cond,
250+        #    (lambda x, y: x.clone()),
251+        #    (lambda x, y: y.clone()),
252+        #    [attention_scaling, self.attention_scaling],
253+        # )
254+
255+        seq_len = torch.max(position_ids) + 1
256+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
257+
258+        if layer_type is None:
259+            # rope_type = self.rope_type
260+            # max_seq_len_cached = self.max_seq_len_cached
261+            original_inv_freq = self.original_inv_freq
262+            prefix = ""
263+        else:
264+            # rope_type = self.rope_type[layer_type]
265+            # max_seq_len_cached = getattr(
266+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
267+            # )
268+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
269+            prefix = f"{layer_type}_"
270+
271+        # Second test to translate.
272+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
273+        # But in that case the following condition is a way to restore the original cache.
274+
275+        # if (
276+        #    seq_len < self.original_max_seq_len
277+        #    and self.max_seq_len_cached > self.original_max_seq_len
278+        # ):
279+        #    self.original_inv_freq = self.original_inv_freq.to(device)
280+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
281+        #    self.max_seq_len_cached = self.original_max_seq_len
282+
283+        original_inv_freq = self.original_inv_freq.to(device)
284+        cond = (seq_len >= self.original_max_seq_len).item()
285+        # PATCHED: uses torch.cond instead of a test
286+        inv_freq = torch.cond(
287+            cond,
288+            (lambda x, y: x.clone()),
289+            (lambda x, y: y.clone()),
290+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
291+        )
292+        setattr(self, f"{prefix}inv_freq", inv_freq)
293
294     @wraps(rope_forward)
295     def wrapper(self, x, position_ids, layer_type=None):
296-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
297-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
298-        if "dynamic" in rope_type:
299-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
300-        elif rope_type == "longrope":
301-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
302-        return rope_forward(self, x, position_ids, **kwargs)
303+        if layer_type is None:
304+            if "dynamic" in self.rope_type:
305+                dynamic_frequency_update(self, position_ids, device=x.device)
306+            elif self.rope_type == "longrope":
307+                longrope_frequency_update(self, position_ids, device=x.device)
308+            return rope_forward(self, x, position_ids)
309+
310+        if "dynamic" in self.rope_type:
311+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
312+        elif self.rope_type == "longrope":
313+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
314+        return rope_forward(self, x, position_ids, layer_type=layer_type)
315
316     return wrapper

impacted nodes

aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0) -> unsqueeze_12
aten.unsqueeze.default(unsqueeze_12, 2) -> unsqueeze_13
aten._assert_tensor_metadata.default(unsqueeze_13) -> _assert_tensor_metadata_default_3
aten.to.dtype(unsqueeze_13, torch.float32) -> to_3
aten.expand.default(to_3, [sym_size_int_19, -1, 1]) -> expand_1
aten._assert_tensor_metadata.default(expand_1) -> _assert_tensor_metadata_default_4
aten.to.dtype_layout(expand_1) -> to_4
aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807) -> slice_4
aten.unsqueeze.default(slice_4, 1) -> unsqueeze_14
aten.slice.Tensor(unsqueeze_14, 2, 0, 9223372036854775807) -> slice_5
aten._assert_tensor_metadata.default(slice_5) -> _assert_tensor_metadata_default_5
aten.to.dtype(slice_5, torch.float32) -> to_5
<built-in function getitem>(wrap_with_autocast, 0) -> mul
<built-in function getitem>(wrap_with_autocast, 1) -> mul_1
aten._assert_tensor_metadata.default(mul) -> _assert_tensor_metadata_default_8
aten.to.dtype(mul, torch.float32) -> to_8
aten._assert_tensor_metadata.default(mul_1) -> _assert_tensor_metadata_default_9
aten.to.dtype(mul_1, torch.float32) -> to_9
append(family: str, function_to_patch: str | Callable, patch: Callable) → PatchInfo[source]

Stores a patch.

Parameters:
  • family – a category, anything to classify the patch

  • function_to_patch – the function to patch

  • patch – the replacement (patched) function

Returns:

instance of PatchInfo
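
A minimal sketch of registering a custom rewriting with append; the two functions below are hypothetical placeholders, not patches shipped with the package:

from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails

def original_fn(x):
    # hypothetical function to patch
    return x + 1

def patched_fn(x):
    # hypothetical replacement with the same behaviour
    return x + 1

details = PatchDetails()
info = details.append("custom", original_fn, patched_fn)
print(details.n_patches)  # 1, the patch is now stored
print(info.to_dict())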

data() → List[Dict[str, Any]][source]

Returns the data for a dataframe.
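
Since data() returns a list of dictionaries, it can be loaded into a pandas DataFrame to inspect the stored patches. A minimal sketch, assuming pandas is installed and details was filled as in the example at the top of this page:

import pandas

df = pandas.DataFrame(details.data())
print(df)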

find(name: str) → PatchInfo | None[source]

Finds a patch by name.
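
A hedged sketch, assuming the name corresponds to the name of one of the functions involved in a stored patch (the exact matching convention is an assumption); the method returns None when nothing matches:

info = details.find("eager_mask")
print(info if info is None else info.to_dict())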

make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') → str[source]

Creates a report based on the involved patches.

Parameters:
  • patches – list of pairs (patch, impacted nodes), as returned by patches_involved_in_graph

  • format – 'raw' or 'rst'

Returns:

report

matching_pair(patch: PatchInfo, node: torch.fx.Node) → bool[source]

Last validation for a pair. RotaryEmbedding has many rewritings and they all end up in the same code line.

property n_patches: int

Returns the number of stored patches.

patches_involved_in_graph(graph: torch.fx.Graph) → List[Tuple[PatchInfo, List[torch.fx.Node]]][source]

Enumerates all patches impacting a graph. The function goes through the nodes of the main graph only and looks into their metadata to determine whether a listed patch was involved.

Parameters:

graph – fx graph

Returns:

list of pairs (patch, impacted nodes)

class onnx_diagnostic.torch_export_patches.patch_details.PatchInfo(function_to_patch: str | Callable, patch: Callable, family: str = '')[source]

Stores information about patches.

Parameters:
  • function_to_patch – the function to patch

  • patch – the replacement (patched) function

  • family – a category, anything to classify the patch

format_diff(format: str = 'raw') → str[source]

Formats the diff between two functions as a string.

Parameters:

format – 'raw' or 'rst'

Returns:

diff

<<<

import transformers
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

eager_mask = transformers.masking_utils.eager_mask
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
print(diff)

>>>

eager_mask -> patched_eager_mask

 1--- original
 2+++ rewritten
 3@@ -1,46 +1,19 @@
 4-def eager_mask(
 5+def patched_eager_mask(
 6     batch_size: int,
 7     cache_position: torch.Tensor,
 8     kv_length: int,
 9     kv_offset: int = 0,
10     mask_function: Callable = causal_mask_function,
11-    attention_mask: torch.Tensor | None = None,
12+    attention_mask: Optional[torch.Tensor] = None,
13     dtype: torch.dtype = torch.float32,
14-    allow_is_bidirectional_skip: bool = False,
15-    use_vmap: bool = False,
16     **kwargs,
17 ) -> torch.Tensor:
18-    """
19-    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
20-    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
21-    it should not.
22-
23-    Args:
24-        batch_size (`int`):
25-            The batch size of the input sequence.
26-        cache_position (`torch.Tensor`):
27-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
28-        kv_length (`int`):
29-            The size that the key and value states will have during the attention computation.
30-        kv_offset (`int`, optional):
31-            An optional offset to indicate at which first position the key and values states will refer to.
32-        mask_function (`Callable`):
33-            The mask factory function describing the mask pattern.
34-        attention_mask (`torch.Tensor`, optional):
35-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
36-        dtype (`torch.dtype`, optional):
37-            The dtype to use for the mask. By default, `torch.float32`.
38-        allow_is_bidirectional_skip (`bool`, optional):
39-            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
40-            i.e. full attention without any padding. Default to `False`.
41-        use_vmap (`bool`, optional):
42-            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
43-            index-based (for the cost of speed performance). By default `False`.
44-    """
45+    """manual patch for function ``transformers.masking_utils.eager_mask``."""
46     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
47     _ = kwargs.pop("allow_is_causal_skip", None)
48-    _ = kwargs.pop("allow_torch_fix", None)
49-    mask = sdpa_mask(
50+    _ = kwargs.pop("allow_is_bidirectional_skip", None)
51+    # PATCHED: this line called the patched version of sdpa_mask
52+    mask = patched_sdpa_mask_recent_torch(
53         batch_size=batch_size,
54         cache_position=cache_position,
55         kv_length=kv_length,
56@@ -48,14 +21,15 @@
57         mask_function=mask_function,
58         attention_mask=attention_mask,
59         allow_is_causal_skip=False,
60-        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
61+        allow_is_bidirectional_skip=False,
62         allow_torch_fix=False,
63-        use_vmap=use_vmap,
64         **kwargs,
65     )
66-    # only bidirectional masks can be skipped, otherwise we convert bool -> float
67-    if mask is not None:
68-        min_dtype = torch.finfo(dtype).min
69-        # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
70-        mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
71+    min_dtype = torch.finfo(dtype).min
72+    # PATCHED: the following line
73+    # we need 0s where the tokens should be taken into account,
74+    # and -inf otherwise (mask is already of boolean type)
75+    # mask =
76+    #   torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
77+    mask = (~mask).to(dtype) * min_dtype
78     return mask

make_diff() → str[source]

Returns a diff as a string.
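
A minimal sketch, reusing the pair of functions from the format_diff example above; make_diff presumably returns the same diff without the RST decoration:

diff = PatchInfo(eager_mask, patched_eager_mask).make_diff()
print(diff)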

to_dict() → Dict[str, Any][source]

Usual conversion to a dictionary.

to_tuple() → Tuple[str, Callable, Callable][source]

Usual conversion to a tuple.

onnx_diagnostic.torch_export_patches.patch_details.clean_code_with_black(code: str) → str[source]

Changes the code style with black if available.
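
A minimal sketch; with black installed the snippet comes back normalized, otherwise the code is presumably returned unchanged:

from onnx_diagnostic.torch_export_patches.patch_details import clean_code_with_black

code = "def f( x ):\n    return (x+1)"
print(clean_code_with_black(code))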

onnx_diagnostic.torch_export_patches.patch_details.make_diff_code(code1: str, code2: str, output: str | None = None) → str[source]

Creates a diff between two pieces of code.

Parameters:
  • code1 – first code

  • code2 – second code

  • output – if not empty, stores the output in this file

Returns:

diff
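
A minimal sketch comparing two small snippets (the optional output file is left out):

from onnx_diagnostic.torch_export_patches.patch_details import make_diff_code

code1 = "def f(x):\n    return x + 1\n"
code2 = "def f(x):\n    return x + 2\n"
print(make_diff_code(code1, code2))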