onnx_diagnostic.torch_export_patches.patch_details

class onnx_diagnostic.torch_export_patches.patch_details.PatchDetails[source]

This class stores patching information. It helps understand which rewriting was applied to which method or function. The page Patches Diff contains the diffs for all the implemented patches.

<<<

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
details = PatchDetails()
with torch_export_patches(
    patch_transformers=True, patch_details=details, patch_torch=False
):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
patches = details.patches_involded_in_graph(ep.graph)
report = details.make_report(patches, format="rst")
print(report)

>>>

auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward

  1--- original
  2+++ rewritten
  3@@ -1,18 +1,26 @@
  4-@torch.no_grad()
  5-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
  6-def forward(self, x, position_ids):
  7+@patched_dynamic_rope_update
  8+def forward(self, x, position_ids, layer_type=None):
  9+    if layer_type is not None:
 10+        # transformers>=5.0
 11+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
 12+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
 13+    else:
 14+        # transformers<5.0
 15+        inv_freq = self.inv_freq
 16+        attention_scaling = self.attention_scaling
 17+
 18     inv_freq_expanded = (
 19-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 20+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
 21     )
 22     position_ids_expanded = position_ids[:, None, :].float()
 23
 24     device_type = (
 25         x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
 26     )
 27-    with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
 28+    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
 29         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
 30         emb = torch.cat((freqs, freqs), dim=-1)
 31-        cos = emb.cos() * self.attention_scaling
 32-        sin = emb.sin() * self.attention_scaling
 33+        cos = emb.cos() * attention_scaling
 34+        sin = emb.sin() * attention_scaling
 35
 36     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 37
 38--- original
 39+++ rewritten
 40@@ -1,99 +1,193 @@
 41-def dynamic_rope_update(rope_forward):
 42-    """
 43-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
 44-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
 45+def patched_dynamic_rope_update(rope_forward):
 46+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 47
 48-    Args:
 49-        rope_forward (Callable):
 50-            The forward pass of the RoPE implementation.
 51+    ``rope_type`` is determined in the constructor of class
 52+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.
 53
 54-    Returns:
 55-        The decorated forward pass.
 56+    .. code-block:: python
 57+
 58+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 59+            self.rope_type = config.rope_scaling.get(
 60+                "rope_type", config.rope_scaling.get("type"))
 61+        else:
 62+            self.rope_type = "default"
 63+
 64+    The original code of the patched function:
 65+
 66+    .. code-block:: python
 67+
 68+        def dynamic_rope_update(rope_forward):
 69+            def longrope_frequency_update(self, position_ids, device):
 70+                seq_len = torch.max(position_ids) + 1
 71+                if hasattr(self.config, "original_max_position_embeddings"):
 72+                    original_max_position_embeddings =
 73+                        self.config.original_max_position_embeddings
 74+                else:
 75+                    original_max_position_embeddings =
 76+                        self.config.max_position_embeddings
 77+                if seq_len > original_max_position_embeddings:
 78+                    if not hasattr(self, "long_inv_freq"):
 79+                        self.long_inv_freq, _ = self.rope_init_fn(
 80+                            self.config, device, seq_len=original_max_position_embeddings + 1
 81+                        )
 82+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
 83+                else:
 84+                    self.original_inv_freq = self.original_inv_freq.to(device)
 85+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 86+
 87+            def dynamic_frequency_update(self, position_ids, device):
 88+                seq_len = torch.max(position_ids) + 1
 89+                if seq_len > self.max_seq_len_cached:  # growth
 90+                    inv_freq, self.attention_scaling = self.rope_init_fn(
 91+                        self.config, device, seq_len=seq_len)
 92+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
 93+                    self.max_seq_len_cached = seq_len
 94+
 95+                if seq_len < self.original_max_seq_len and
 96+                        self.max_seq_len_cached > self.original_max_seq_len:
 97+                    self.original_inv_freq = self.original_inv_freq.to(device)
 98+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
 99+                    self.max_seq_len_cached = self.original_max_seq_len
100+
101+            @wraps(rope_forward)
102+            def wrapper(self, x, position_ids):
103+                if "dynamic" in self.rope_type:
104+                    dynamic_frequency_update(self, position_ids, device=x.device)
105+                elif self.rope_type == "longrope":
106+                    longrope_frequency_update(self, position_ids, device=x.device)
107+                return rope_forward(self, x, position_ids)
108+
109+            return wrapper
110+
111     """
112
113     def longrope_frequency_update(self, position_ids, device, layer_type=None):
114-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
115+        # It is no use to patch the function after the model is created
116+        # as rope_init_fn is an attribute set to one function when the model
117+        # is created and when no patch is applied yet.
118+        # So we select the patched version here.
119+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
120         seq_len = torch.max(position_ids) + 1
121-        original_max_position_embeddings = getattr(
122-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
123-        )
124+        if hasattr(self.config, "original_max_position_embeddings"):
125+            original_max_position_embeddings = self.config.original_max_position_embeddings
126+        else:
127+            original_max_position_embeddings = self.config.max_position_embeddings
128+
129         if layer_type is None:
130-            rope_type = self.rope_type
131+            # rope_type = self.rope_type
132             original_inv_freq = self.original_inv_freq
133             prefix = ""
134         else:
135-            rope_type = self.rope_type[layer_type]
136+            # rope_type = self.rope_type[layer_type]
137             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
138             prefix = f"{layer_type}_"
139
140-        if seq_len > original_max_position_embeddings:
141-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
142-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
143-                long_inv_freq, _ = rope_init_fn(
144-                    self.config,
145-                    device,
146-                    seq_len=original_max_position_embeddings + 1,
147-                    layer_type=layer_type,
148-                )
149-            self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
150-            setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
151-        else:
152-            # This .to() is needed if the model has been moved to a device after being initialized (because
153-            # the buffer is automatically moved, but not the original copy)
154-            original_inv_freq = original_inv_freq.to(device)
155-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
156-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
157+        # At export time, seq_len is unknown.
158+        long_inv_freq, _ = rope_init_fn(
159+            self.config, device, seq_len=original_max_position_embeddings + 1
160+        )
161+        original_inv_freq = self.original_inv_freq.to(device)
162+
163+        # PATCHED: uses torch.cond instead of a test
164+        cond = (seq_len > original_max_position_embeddings).item()
165+        inv_freq = torch.cond(
166+            cond,
167+            (lambda x, y: x.clone()),
168+            (lambda x, y: y.clone()),
169+            [long_inv_freq, original_inv_freq],
170+        )
171+        setattr(self, f"{prefix}inv_freq", inv_freq)
172+        # if seq_len > original_max_position_embeddings:
173+        #    self.inv_freq = self.long_inv_freq
174+        # else:
175+        #    self.inv_freq = self.original_inv_freq
176
177     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
178-        """
179-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
180-        1 - growing beyond the cached sequence length (allow scaling)
181-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
182-        """
183+        # constructor:
184+        # - self.max_seq_len_cached = config.max_position_embeddings
185+        # - self.original_max_seq_len = config.max_position_embeddings
186+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
187+
188+        # It is no use to patch the function after the model is created
189+        # as rope_init_fn is an attribute set to one function when the model
190+        # is created and when no patch is applied yet.
191+        # So we select the patched version here.
192+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
193+
194+        # This behaviour is difficult to translate.
195+        # The sequence always grows.
196+        # The test should always True.
197+        # So:  self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
198+        #
199+        # if seq_len > self.max_seq_len_cached:  # growth
200+        #    inv_freq, self.attention_scaling = self.rope_init_fn(
201+        #        self.config, device, seq_len=seq_len
202+        #    )
203+        #    self.register_buffer("inv_freq", inv_freq, persistent=False)
204+        #    self.max_seq_len_cached = seq_len
205+        #
206+        # So we should not need what follows.
207+        #
208+        # cond = (seq_len > self.max_seq_len_cached).item()
209+        # self.attention_scaling = torch.cond(
210+        #    cond,
211+        #    (lambda x, y: x.clone()),
212+        #    (lambda x, y: y.clone()),
213+        #    [attention_scaling, self.attention_scaling],
214+        # )
215+
216         seq_len = torch.max(position_ids) + 1
217+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
218+
219         if layer_type is None:
220-            rope_type = self.rope_type
221-            max_seq_len_cached = self.max_seq_len_cached
222+            # rope_type = self.rope_type
223+            # max_seq_len_cached = self.max_seq_len_cached
224             original_inv_freq = self.original_inv_freq
225             prefix = ""
226         else:
227-            rope_type = self.rope_type[layer_type]
228-            max_seq_len_cached = getattr(
229-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
230-            )
231+            # rope_type = self.rope_type[layer_type]
232+            # max_seq_len_cached = getattr(
233+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
234+            # )
235             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
236             prefix = f"{layer_type}_"
237
238-        if seq_len > max_seq_len_cached:  # growth
239-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
240-            inv_freq, self.attention_scaling = rope_init_fn(
241-                self.config,
242-                device,
243-                seq_len=seq_len,
244-                layer_type=layer_type,
245-            )
246-            # TODO joao: may break with compilation
247-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
248-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
249+        # Second test to translate.
250+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
251+        # But in that case the following condition is a way to restore the original cache.
252
253-        if (
254-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
255-        ):  # reset
256-            # This .to() is needed if the model has been moved to a device after being initialized (because
257-            # the buffer is automatically moved, but not the original copy)
258-            original_inv_freq = original_inv_freq.to(device)
259-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
260-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
261-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
262+        # if (
263+        #    seq_len < self.original_max_seq_len
264+        #    and self.max_seq_len_cached > self.original_max_seq_len
265+        # ):
266+        #    self.original_inv_freq = self.original_inv_freq.to(device)
267+        #    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
268+        #    self.max_seq_len_cached = self.original_max_seq_len
269+
270+        original_inv_freq = self.original_inv_freq.to(device)
271+        cond = (seq_len >= self.original_max_seq_len).item()
272+        # PATCHED: uses torch.cond instead of a test
273+        inv_freq = torch.cond(
274+            cond,
275+            (lambda x, y: x.clone()),
276+            (lambda x, y: y.clone()),
277+            [long_inv_freq, original_inv_freq],
278+        )
279+        setattr(self, f"{prefix}inv_freq", inv_freq)
280
281     @wraps(rope_forward)
282     def wrapper(self, x, position_ids, layer_type=None):
283-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
284-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
285-        if "dynamic" in rope_type:
286-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
287-        elif rope_type == "longrope":
288-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
289-        return rope_forward(self, x, position_ids, **kwargs)
290+        if layer_type is None:
291+            if "dynamic" in self.rope_type:
292+                dynamic_frequency_update(self, position_ids, device=x.device)
293+            elif self.rope_type == "longrope":
294+                longrope_frequency_update(self, position_ids, device=x.device)
295+            return rope_forward(self, x, position_ids)
296+
297+        if "dynamic" in self.rope_type:
298+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
299+        elif self.rope_type == "longrope":
300+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
301+        return rope_forward(self, x, position_ids, layer_type=layer_type)
302
303     return wrapper

impacted nodes

aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0) -> unsqueeze_12
aten.unsqueeze.default(unsqueeze_12, 2) -> unsqueeze_13
aten._assert_tensor_metadata.default(unsqueeze_13) -> _assert_tensor_metadata_default_3
aten.to.dtype(unsqueeze_13, torch.float32) -> to_3
aten.expand.default(to_3, [sym_size_int_19, -1, 1]) -> expand_1
aten._assert_tensor_metadata.default(expand_1) -> _assert_tensor_metadata_default_4
aten.to.dtype_layout(expand_1) -> to_4
aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807) -> slice_4
aten.unsqueeze.default(slice_4, 1) -> unsqueeze_14
aten.slice.Tensor(unsqueeze_14, 2, 0, 9223372036854775807) -> slice_5
aten._assert_tensor_metadata.default(slice_5) -> _assert_tensor_metadata_default_5
aten.to.dtype(slice_5, torch.float32) -> to_5
<built-in function getitem>(wrap_with_autocast, 0) -> mul
<built-in function getitem>(wrap_with_autocast, 1) -> mul_1
aten._assert_tensor_metadata.default(mul) -> _assert_tensor_metadata_default_8
aten.to.dtype(mul, torch.float32) -> to_8
aten._assert_tensor_metadata.default(mul_1) -> _assert_tensor_metadata_default_9
aten.to.dtype(mul_1, torch.float32) -> to_9
append(family: str, function_to_patch: str | Callable, patch: Callable) → PatchInfo[source]

Stores a patch.

Parameters:
  • family – a category, anything to classify the patch

  • function_to_patch – function to patch

  • patch – the patched (replacement) function

Returns:

instance of PatchInfo
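
A minimal sketch of appending a patch by hand, assuming plain Python callables are accepted for both arguments (the two functions below are placeholders, not real patches):

from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails


def original_fn(x):
    return x + 1


def patched_fn(x):
    # hypothetical rewrite of original_fn, for illustration only
    return x + 1


details = PatchDetails()
info = details.append("manual", original_fn, patched_fn)  # returns a PatchInfo
print(details.n_patches)  # number of stored patches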

data() → List[Dict[str, Any]][source]

Returns the stored patches as a list of dictionaries, suitable for building a dataframe.
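
Since the result is a list of dictionaries, it can be fed directly to pandas. A sketch, assuming pandas is installed (the appended callables are placeholders):

import pandas
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails


def original_fn(x):
    return x


def patched_fn(x):
    return x


details = PatchDetails()
details.append("manual", original_fn, patched_fn)
print(pandas.DataFrame(details.data()))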

find(name: str) → PatchInfo | None[source]

Finds a patch by name.
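
A sketch of a lookup; which name is matched (the original function's or the patch's) is an assumption here:

from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails


def original_fn(x):
    return x


def patched_fn(x):
    return x


details = PatchDetails()
details.append("manual", original_fn, patched_fn)
print(details.find("original_fn"))  # a PatchInfo if the name matches, None otherwise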

make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') → str[source]

Creates a report based on the involved patches.

Parameters:
  • patches – list of pairs (patch, impacted nodes), as returned by patches_involded_in_graph

  • format – 'raw' or 'rst'

Returns:

report

matching_pair(patch: PatchInfo, node: torch.fx.Node) → bool[source]

Last validation for a pair. RotaryEmbedding has many rewritings and they all end up in the same code line.

property n_patches: int

Returns the number of stored patches.

patches_involded_in_graph(graph: torch.fx.Graph) → List[Tuple[PatchInfo, List[torch.fx.Node]]][source]

Enumerates all patches impacting a graph. The function goes through the graph nodes (only the main graph) and looks at their metadata to determine whether a listed patch was involved.

Parameters:

graph – fx graph

Returns:

list of pairs (patch, nodes impacted by that patch)

class onnx_diagnostic.torch_export_patches.patch_details.PatchInfo(function_to_patch: str | Callable, patch: Callable, family: str = '')[source]

Stores information about patches.

Parameters:
  • function_to_patch – function to patch

  • patch – the patched (replacement) function

  • family – a category, anything to classify the patch

format_diff(format: str = 'raw') → str[source]

Formats the diff between the two functions as a string.

Parameters:

format – 'raw' or 'rst'

Returns:

diff

<<<

import transformers
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

eager_mask = transformers.masking_utils.eager_mask
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
print(diff)

>>>

eager_mask -> patched_eager_mask

 1--- original
 2+++ rewritten
 3@@ -1,4 +1,4 @@
 4-def eager_mask(
 5+def patched_eager_mask(
 6     batch_size: int,
 7     cache_position: torch.Tensor,
 8     kv_length: int,
 9@@ -6,41 +6,14 @@
10     mask_function: Callable = causal_mask_function,
11     attention_mask: Optional[torch.Tensor] = None,
12     dtype: torch.dtype = torch.float32,
13-    allow_is_bidirectional_skip: bool = False,
14-    use_vmap: bool = False,
15     **kwargs,
16 ) -> torch.Tensor:
17-    """
18-    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
19-    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
20-    it should not.
21-
22-    Args:
23-        batch_size (`int`):
24-            The batch size of the input sequence.
25-        cache_position (`torch.Tensor`):
26-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
27-        kv_length (`int`):
28-            The size that the key and value states will have during the attention computation.
29-        kv_offset (`int`, optional):
30-            An optional offset to indicate at which first position the key and values states will refer to.
31-        mask_function (`Callable`):
32-            The mask factory function describing the mask pattern.
33-        attention_mask (`torch.Tensor`, optional):
34-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
35-        dtype (`torch.dtype`, optional):
36-            The dtype to use for the mask. By default, `torch.float32`.
37-        allow_is_bidirectional_skip (`bool`, optional):
38-            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
39-            i.e. full attention without any padding. Default to `False`.
40-        use_vmap (`bool`, optional):
41-            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
42-            index-based (for the cost of speed performance). By default `False`.
43-    """
44+    """manual patch for function ``transformers.masking_utils.eager_mask``."""
45     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
46     _ = kwargs.pop("allow_is_causal_skip", None)
47-    _ = kwargs.pop("allow_torch_fix", None)
48-    mask = sdpa_mask(
49+    _ = kwargs.pop("allow_is_bidirectional_skip", None)
50+    # PATCHED: this line called the patched version of sdpa_mask
51+    mask = patched_sdpa_mask_recent_torch(
52         batch_size=batch_size,
53         cache_position=cache_position,
54         kv_length=kv_length,
55@@ -48,14 +21,15 @@
56         mask_function=mask_function,
57         attention_mask=attention_mask,
58         allow_is_causal_skip=False,
59-        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
60+        allow_is_bidirectional_skip=False,
61         allow_torch_fix=False,
62-        use_vmap=use_vmap,
63         **kwargs,
64     )
65-    # only bidirectional masks can be skipped, otherwise we convert bool -> float
66-    if mask is not None:
67-        min_dtype = torch.finfo(dtype).min
68-        # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
69-        mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
70+    min_dtype = torch.finfo(dtype).min
71+    # PATCHED: the following line
72+    # we need 0s where the tokens should be taken into account,
73+    # and -inf otherwise (mask is already of boolean type)
74+    # mask =
75+    #   torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
76+    mask = (~mask).to(dtype) * min_dtype
77     return mask

make_diff() → str[source]

Returns a diff as a string.
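
A sketch reusing the pair of functions from the example above; the result is presumably the same diff as format_diff with the default 'raw' format:

import transformers
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

info = PatchInfo(transformers.masking_utils.eager_mask, patched_eager_mask)
print(info.make_diff())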

to_dict() → Dict[str, Any][source]

Usual conversion to a dictionary.

to_tuple() → Tuple[str, Callable, Callable][source]

Usual conversion to a tuple.

onnx_diagnostic.torch_export_patches.patch_details.clean_code_with_black(code: str) → str[source]

Reformats the code with black if it is available.
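
A quick sketch; if black is not installed, the code is presumably returned unchanged:

from onnx_diagnostic.torch_export_patches.patch_details import clean_code_with_black

print(clean_code_with_black("def f( x ):\n    return x+1"))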

onnx_diagnostic.torch_export_patches.patch_details.make_diff_code(code1: str, code2: str, output: str | None = None) → str[source]

Creates a diff between two pieces of code.

Parameters:
  • code1 – first code

  • code2 – second code

  • output – if not empty, stores the output in this file

Returns:

diff
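
A sketch diffing two small snippets; the exact layout of the returned string is assumed to follow the unified-diff reports shown above:

from onnx_diagnostic.torch_export_patches.patch_details import make_diff_code

code1 = "def f(x):\n    return x + 1\n"
code2 = "def f(x, bias=0):\n    return x + bias + 1\n"
print(make_diff_code(code1, code2))
# pass output="diff.txt" to also store the diff in a file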