onnx_diagnostic.torch_export_patches.patch_details¶
- class onnx_diagnostic.torch_export_patches.patch_details.PatchDetails[source][source]¶
This class is used to store patching information. It helps in understanding which rewriting was applied to which method or function. The page Patches Diff contains the diffs for all the implemented patches.
<<<
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
details = PatchDetails()
with torch_export_patches(
    patch_transformers=True, patch_details=details, patch_torch=False
):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
patches = details.patches_involved_in_graph(ep.graph)
report = details.make_report(patches, format="rst")
print(report)
>>>
auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
1--- original 2+++ rewritten 3@@ -1,18 +1,26 @@ 4-@torch.no_grad() 5-@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) 6-def forward(self, x, position_ids): 7+@patched_dynamic_rope_update 8+def forward(self, x, position_ids, layer_type=None): 9+ if layer_type is not None: 10+ # transformers>=5.0 11+ inv_freq = getattr(self, f"{layer_type}_inv_freq") 12+ attention_scaling = getattr(self, f"{layer_type}_attention_scaling") 13+ else: 14+ # transformers<5.0 15+ inv_freq = self.inv_freq 16+ attention_scaling = self.attention_scaling 17+ 18 inv_freq_expanded = ( 19- self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) 20+ inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) 21 ) 22 position_ids_expanded = position_ids[:, None, :].float() 23 24 device_type = ( 25 x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" 26 ) 27- with maybe_autocast(device_type=device_type, enabled=False): # Force float32 28+ with torch.autocast(device_type=device_type, enabled=False): # Force float32 29 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) 30 emb = torch.cat((freqs, freqs), dim=-1) 31- cos = emb.cos() * self.attention_scaling 32- sin = emb.sin() * self.attention_scaling 33+ cos = emb.cos() * attention_scaling 34+ sin = emb.sin() * attention_scaling 35 36 return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 37 38--- original 39+++ rewritten 40@@ -1,103 +1,193 @@ 41-def dynamic_rope_update(rope_forward): 42- """ 43- Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE 44- (i.e. a RoPE implementation that may recompute its frequencies in the forward pass). 
45+def patched_dynamic_rope_update(rope_forward): 46+ """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]`` 47 48- Args: 49- rope_forward (Callable): 50- The forward pass of the RoPE implementation. 51+ ``rope_type`` is determined in the constructor of class 52+ :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`. 53 54- Returns: 55- The decorated forward pass. 56+ .. code-block:: python 57+ 58+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None: 59+ self.rope_type = config.rope_scaling.get( 60+ "rope_type", config.rope_scaling.get("type")) 61+ else: 62+ self.rope_type = "default" 63+ 64+ The original code of the patched function: 65+ 66+ .. code-block:: python 67+ 68+ def dynamic_rope_update(rope_forward): 69+ def longrope_frequency_update(self, position_ids, device): 70+ seq_len = torch.max(position_ids) + 1 71+ if hasattr(self.config, "original_max_position_embeddings"): 72+ original_max_position_embeddings = 73+ self.config.original_max_position_embeddings 74+ else: 75+ original_max_position_embeddings = 76+ self.config.max_position_embeddings 77+ if seq_len > original_max_position_embeddings: 78+ if not hasattr(self, "long_inv_freq"): 79+ self.long_inv_freq, _ = self.rope_init_fn( 80+ self.config, device, seq_len=original_max_position_embeddings + 1 81+ ) 82+ self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) 83+ else: 84+ self.original_inv_freq = self.original_inv_freq.to(device) 85+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) 86+ 87+ def dynamic_frequency_update(self, position_ids, device): 88+ seq_len = torch.max(position_ids) + 1 89+ if seq_len > self.max_seq_len_cached: # growth 90+ inv_freq, self.attention_scaling = self.rope_init_fn( 91+ self.config, device, seq_len=seq_len) 92+ self.register_buffer("inv_freq", inv_freq, persistent=False) 93+ self.max_seq_len_cached = seq_len 94+ 95+ if seq_len < self.original_max_seq_len and 96+ 
self.max_seq_len_cached > self.original_max_seq_len: 97+ self.original_inv_freq = self.original_inv_freq.to(device) 98+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) 99+ self.max_seq_len_cached = self.original_max_seq_len 100+ 101+ @wraps(rope_forward) 102+ def wrapper(self, x, position_ids): 103+ if "dynamic" in self.rope_type: 104+ dynamic_frequency_update(self, position_ids, device=x.device) 105+ elif self.rope_type == "longrope": 106+ longrope_frequency_update(self, position_ids, device=x.device) 107+ return rope_forward(self, x, position_ids) 108+ 109+ return wrapper 110+ 111 """ 112 113 def longrope_frequency_update(self, position_ids, device, layer_type=None): 114- """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" 115+ # It is no use to patch the function after the model is created 116+ # as rope_init_fn is an attribute set to one function when the model 117+ # is created and when no patch is applied yet. 118+ # So we select the patched version here. 
119+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type) 120 seq_len = torch.max(position_ids) + 1 121+ if hasattr(self.config, "original_max_position_embeddings"): 122+ original_max_position_embeddings = self.config.original_max_position_embeddings 123+ else: 124+ original_max_position_embeddings = self.config.max_position_embeddings 125 126 if layer_type is None: 127- rope_type = self.rope_type 128- original_inv_freq = self.original_inv_freq 129- prefix = "" 130- original_max_position_embeddings = self.config.rope_parameters[ 131- "original_max_position_embeddings" 132- ] 133- else: 134- rope_type = self.rope_type[layer_type] 135- original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") 136- prefix = f"{layer_type}_" 137- original_max_position_embeddings = self.config.rope_parameters[layer_type][ 138- "original_max_position_embeddings" 139- ] 140- 141- if seq_len > original_max_position_embeddings: 142- if not hasattr(self, f"{layer_type}_long_inv_freq"): 143- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] 144- long_inv_freq, _ = rope_init_fn( 145- self.config, 146- device, 147- seq_len=original_max_position_embeddings + 1, 148- layer_type=layer_type, 149- ) 150- self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False) 151- setattr(self, f"{prefix}long_inv_freq", long_inv_freq) 152- else: 153- # This .to() is needed if the model has been moved to a device after being initialized (because 154- # the buffer is automatically moved, but not the original copy) 155- original_inv_freq = original_inv_freq.to(device) 156- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) 157- setattr(self, f"{prefix}original_inv_freq", original_inv_freq) 158- 159- def dynamic_frequency_update(self, position_ids, device, layer_type=None): 160- """ 161- dynamic RoPE layers should recompute `inv_freq` in the following situations: 162- 1 - growing beyond the cached sequence length (allow scaling) 163- 2 - the current sequence 
length is in the original scale (avoid losing precision with small sequences) 164- """ 165- seq_len = torch.max(position_ids) + 1 166- if layer_type is None: 167- rope_type = self.rope_type 168- max_seq_len_cached = self.max_seq_len_cached 169+ # rope_type = self.rope_type 170 original_inv_freq = self.original_inv_freq 171 prefix = "" 172 else: 173- rope_type = self.rope_type[layer_type] 174- max_seq_len_cached = getattr( 175- self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached 176- ) 177+ # rope_type = self.rope_type[layer_type] 178 original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") 179 prefix = f"{layer_type}_" 180 181- if seq_len > max_seq_len_cached: # growth 182- rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] 183- inv_freq, self.attention_scaling = rope_init_fn( 184- self.config, 185- device, 186- seq_len=seq_len, 187- layer_type=layer_type, 188- ) 189- # TODO joao: may break with compilation 190- self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False) 191- setattr(self, f"{layer_type}_max_seq_len_cached", seq_len) 192+ # At export time, seq_len is unknown. 
193+ long_inv_freq, _ = rope_init_fn( 194+ self.config, device, seq_len=original_max_position_embeddings + 1 195+ ) 196+ original_inv_freq = self.original_inv_freq.to(device) 197 198- if ( 199- seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len 200- ): # reset 201- # This .to() is needed if the model has been moved to a device after being initialized (because 202- # the buffer is automatically moved, but not the original copy) 203- original_inv_freq = original_inv_freq.to(device) 204- self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) 205- setattr(self, f"{prefix}original_inv_freq", original_inv_freq) 206- setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len) 207+ # PATCHED: uses torch.cond instead of a test 208+ cond = (seq_len > original_max_position_embeddings).item() 209+ inv_freq = torch.cond( 210+ cond, 211+ (lambda x, y: x.clone()), 212+ (lambda x, y: y.clone()), 213+ [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq], 214+ ) 215+ setattr(self, f"{prefix}inv_freq", inv_freq) 216+ # if seq_len > original_max_position_embeddings: 217+ # self.inv_freq = self.long_inv_freq 218+ # else: 219+ # self.inv_freq = self.original_inv_freq 220+ 221+ def dynamic_frequency_update(self, position_ids, device, layer_type=None): 222+ # constructor: 223+ # - self.max_seq_len_cached = config.max_position_embeddings 224+ # - self.original_max_seq_len = config.max_position_embeddings 225+ # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) 226+ 227+ # It is no use to patch the function after the model is created 228+ # as rope_init_fn is an attribute set to one function when the model 229+ # is created and when no patch is applied yet. 230+ # So we select the patched version here. 231+ rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type) 232+ 233+ # This behaviour is difficult to translate. 234+ # The sequence always grows. 
235+ # The test should always True. 236+ # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len 237+ # 238+ # if seq_len > self.max_seq_len_cached: # growth 239+ # inv_freq, self.attention_scaling = self.rope_init_fn( 240+ # self.config, device, seq_len=seq_len 241+ # ) 242+ # self.register_buffer("inv_freq", inv_freq, persistent=False) 243+ # self.max_seq_len_cached = seq_len 244+ # 245+ # So we should not need what follows. 246+ # 247+ # cond = (seq_len > self.max_seq_len_cached).item() 248+ # self.attention_scaling = torch.cond( 249+ # cond, 250+ # (lambda x, y: x.clone()), 251+ # (lambda x, y: y.clone()), 252+ # [attention_scaling, self.attention_scaling], 253+ # ) 254+ 255+ seq_len = torch.max(position_ids) + 1 256+ long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len) 257+ 258+ if layer_type is None: 259+ # rope_type = self.rope_type 260+ # max_seq_len_cached = self.max_seq_len_cached 261+ original_inv_freq = self.original_inv_freq 262+ prefix = "" 263+ else: 264+ # rope_type = self.rope_type[layer_type] 265+ # max_seq_len_cached = getattr( 266+ # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached 267+ # ) 268+ original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") 269+ prefix = f"{layer_type}_" 270+ 271+ # Second test to translate. 272+ # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True. 273+ # But in that case the following condition is a way to restore the original cache. 
274+ 275+ # if ( 276+ # seq_len < self.original_max_seq_len 277+ # and self.max_seq_len_cached > self.original_max_seq_len 278+ # ): 279+ # self.original_inv_freq = self.original_inv_freq.to(device) 280+ # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) 281+ # self.max_seq_len_cached = self.original_max_seq_len 282+ 283+ original_inv_freq = self.original_inv_freq.to(device) 284+ cond = (seq_len >= self.original_max_seq_len).item() 285+ # PATCHED: uses torch.cond instead of a test 286+ inv_freq = torch.cond( 287+ cond, 288+ (lambda x, y: x.clone()), 289+ (lambda x, y: y.clone()), 290+ [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq], 291+ ) 292+ setattr(self, f"{prefix}inv_freq", inv_freq) 293 294 @wraps(rope_forward) 295 def wrapper(self, x, position_ids, layer_type=None): 296- rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type] 297- kwargs = {"layer_type": layer_type} if layer_type is not None else {} 298- if "dynamic" in rope_type: 299- dynamic_frequency_update(self, position_ids, device=x.device, **kwargs) 300- elif rope_type == "longrope": 301- longrope_frequency_update(self, position_ids, device=x.device, **kwargs) 302- return rope_forward(self, x, position_ids, **kwargs) 303+ if layer_type is None: 304+ if "dynamic" in self.rope_type: 305+ dynamic_frequency_update(self, position_ids, device=x.device) 306+ elif self.rope_type == "longrope": 307+ longrope_frequency_update(self, position_ids, device=x.device) 308+ return rope_forward(self, x, position_ids) 309+ 310+ if "dynamic" in self.rope_type: 311+ dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type) 312+ elif self.rope_type == "longrope": 313+ longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type) 314+ return rope_forward(self, x, position_ids, layer_type=layer_type) 315 316 return wrapper
impacted nodes
aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0) -> unsqueeze_12 aten.unsqueeze.default(unsqueeze_12, 2) -> unsqueeze_13 aten._assert_tensor_metadata.default(unsqueeze_13) -> _assert_tensor_metadata_default_3 aten.to.dtype(unsqueeze_13, torch.float32) -> to_3 aten.expand.default(to_3, [sym_size_int_19, -1, 1]) -> expand_1 aten._assert_tensor_metadata.default(expand_1) -> _assert_tensor_metadata_default_4 aten.to.dtype_layout(expand_1) -> to_4 aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807) -> slice_4 aten.unsqueeze.default(slice_4, 1) -> unsqueeze_14 aten.slice.Tensor(unsqueeze_14, 2, 0, 9223372036854775807) -> slice_5 aten._assert_tensor_metadata.default(slice_5) -> _assert_tensor_metadata_default_5 aten.to.dtype(slice_5, torch.float32) -> to_5 <built-in function getitem>(wrap_with_autocast, 0) -> mul <built-in function getitem>(wrap_with_autocast, 1) -> mul_1 aten._assert_tensor_metadata.default(mul) -> _assert_tensor_metadata_default_8 aten.to.dtype(mul, torch.float32) -> to_8 aten._assert_tensor_metadata.default(mul_1) -> _assert_tensor_metadata_default_9 aten.to.dtype(mul_1, torch.float32) -> to_9
- append(family: str, function_to_patch: str | Callable, patch: Callable) PatchInfo[source][source]¶
Stores a patch.
- Parameters:
family – a category, anything to classify the patch
function_to_patch – function to patch
patch – function patched
- Returns:
instance of PatchInfo
- make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') str[source][source]¶
Creates a report based on the involved patches.
- Parameters:
patches – from method patches_involved_in_graph()
format – format of the report
- Returns:
report
- matching_pair(patch: PatchInfo, node: torch.fx.Node) bool[source][source]¶
Last validation for a pair. RotaryEmbedding has many rewritings and they all end up on the same code line.
- patches_involved_in_graph(graph: torch.fx.Graph) List[Tuple[PatchInfo, List[torch.fx.Node]]][source][source]¶
Enumerates all patches impacting a graph. The function goes through the graph nodes (only the main graph) and looks into the metadata to determine if a listed patch was involved.
- Parameters:
graph – fx graph
- Returns:
list of pairs (patch, list of nodes impacted by that patch)
- class onnx_diagnostic.torch_export_patches.patch_details.PatchInfo(function_to_patch: str | Callable, patch: Callable, family: str = '')[source][source]¶
Stores information about patches.
- Parameters:
function_to_patch – function to patch
patch – function patched
family – a category, anything to classify the patch
- format_diff(format: str = 'raw') str[source][source]¶
Formats a diff between two functions as a string.
- Parameters:
format – 'raw' or 'rst'
- Returns:
diff
<<<
import transformers
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

eager_mask = transformers.masking_utils.eager_mask
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
print(diff)
>>>
eager_mask -> patched_eager_mask¶
1--- original 2+++ rewritten 3@@ -1,4 +1,4 @@ 4-def eager_mask( 5+def patched_eager_mask( 6 batch_size: int, 7 cache_position: torch.Tensor, 8 kv_length: int, 9@@ -6,41 +6,14 @@ 10 mask_function: Callable = causal_mask_function, 11 attention_mask: Optional[torch.Tensor] = None, 12 dtype: torch.dtype = torch.float32, 13- allow_is_bidirectional_skip: bool = False, 14- use_vmap: bool = False, 15 **kwargs, 16 ) -> torch.Tensor: 17- """ 18- Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that 19- the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that 20- it should not. 21- 22- Args: 23- batch_size (`int`): 24- The batch size of the input sequence. 25- cache_position (`torch.Tensor`): 26- A tensor of shape (query_length,) indicating the current indices of the input sequence elements. 27- kv_length (`int`): 28- The size that the key and value states will have during the attention computation. 29- kv_offset (`int`, optional): 30- An optional offset to indicate at which first position the key and values states will refer to. 31- mask_function (`Callable`): 32- The mask factory function describing the mask pattern. 33- attention_mask (`torch.Tensor`, optional): 34- The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) 35- dtype (`torch.dtype`, optional): 36- The dtype to use for the mask. By default, `torch.float32`. 37- allow_is_bidirectional_skip (`bool`, optional): 38- Whether to allow to return `None` for the mask under conditions where we do not have to add any bias, 39- i.e. full attention without any padding. Default to `False`. 40- use_vmap (`bool`, optional): 41- Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be 42- index-based (for the cost of speed performance). By default `False`. 
43- """ 44+ """manual patch for function ``transformers.masking_utils.eager_mask``.""" 45 # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf 46 _ = kwargs.pop("allow_is_causal_skip", None) 47- _ = kwargs.pop("allow_torch_fix", None) 48- mask = sdpa_mask( 49+ _ = kwargs.pop("allow_is_bidirectional_skip", None) 50+ # PATCHED: this line called the patched version of sdpa_mask 51+ mask = patched_sdpa_mask_recent_torch( 52 batch_size=batch_size, 53 cache_position=cache_position, 54 kv_length=kv_length, 55@@ -48,14 +21,15 @@ 56 mask_function=mask_function, 57 attention_mask=attention_mask, 58 allow_is_causal_skip=False, 59- allow_is_bidirectional_skip=allow_is_bidirectional_skip, 60+ allow_is_bidirectional_skip=False, 61 allow_torch_fix=False, 62- use_vmap=use_vmap, 63 **kwargs, 64 ) 65- # only bidirectional masks can be skipped, otherwise we convert bool -> float 66- if mask is not None: 67- min_dtype = torch.finfo(dtype).min 68- # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type) 69- mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype) 70+ min_dtype = torch.finfo(dtype).min 71+ # PATCHED: the following line 72+ # we need 0s where the tokens should be taken into account, 73+ # and -inf otherwise (mask is already of boolean type) 74+ # mask = 75+ # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype) 76+ mask = (~mask).to(dtype) * min_dtype 77 return mask