onnx_diagnostic.torch_export_patches.patch_details¶
- class onnx_diagnostic.torch_export_patches.patch_details.PatchDetails[source]¶
This class stores patching information. It helps understand which rewriting was applied to which method or function. The page Patches Diff contains all the diffs for all the implemented patches.
<<<
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

details = PatchDetails()
with torch_export_patches(
    patch_transformers=True, patch_details=details, patch_torch=False
):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
patches = details.patches_involded_in_graph(ep.graph)
report = details.make_report(patches, format="rst")
print(report)
>>>
auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
--- original
+++ rewritten
@@ -1,8 +1,16 @@
-@torch.no_grad()
-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-def forward(self, x, position_ids):
+@patched_dynamic_rope_update
+def forward(self, x, position_ids, layer_type=None):
+    if layer_type is not None:
+        # transformers>=5.0
+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+    else:
+        # transformers<5.0
+        inv_freq = self.inv_freq
+        attention_scaling = self.attention_scaling
+
     inv_freq_expanded = (
-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
     )
     position_ids_expanded = position_ids[:, None, :].float()

@@ -12,7 +20,7 @@
     with torch.autocast(device_type=device_type, enabled=False):  # Force float32
         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos() * self.attention_scaling
-        sin = emb.sin() * self.attention_scaling
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling

     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

--- original
+++ rewritten
@@ -1,99 +1,193 @@
-def dynamic_rope_update(rope_forward):
-    """
-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

-    Args:
-        rope_forward (Callable):
-            The forward pass of the RoPE implementation.
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.

-    Returns:
-        The decorated forward pass.
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings =
+                        self.config.original_max_position_embeddings
+                else:
+                    original_max_position_embeddings =
+                        self.config.max_position_embeddings
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if seq_len < self.original_max_seq_len and
+                        self.max_seq_len_cached > self.original_max_seq_len:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """

     def longrope_frequency_update(self, position_ids, device, layer_type=None):
-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
-        original_max_position_embeddings = getattr(
-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
-        )
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+
         if layer_type is None:
-            rope_type = self.rope_type
+            # rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
         else:
-            rope_type = self.rope_type[layer_type]
+            # rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"

-        if seq_len > original_max_position_embeddings:
-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-                long_inv_freq, _ = rope_init_fn(
-                    self.config,
-                    device,
-                    seq_len=original_max_position_embeddings + 1,
-                    layer_type=layer_type,
-                )
-                self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
-                setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
-        else:
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)
+
+        # PATCHED: uses torch.cond instead of a test
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)
+        # if seq_len > original_max_position_embeddings:
+        #     self.inv_freq = self.long_inv_freq
+        # else:
+        #     self.inv_freq = self.original_inv_freq

     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows.
+        # The test should always True.
+        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #     inv_freq, self.attention_scaling = self.rope_init_fn(
+        #         self.config, device, seq_len=seq_len
+        #     )
+        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #     self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #     cond,
+        #     (lambda x, y: x.clone()),
+        #     (lambda x, y: y.clone()),
+        #     [attention_scaling, self.attention_scaling],
+        # )
+
         seq_len = torch.max(position_ids) + 1
+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
+
         if layer_type is None:
-            rope_type = self.rope_type
-            max_seq_len_cached = self.max_seq_len_cached
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
             original_inv_freq = self.original_inv_freq
             prefix = ""
         else:
-            rope_type = self.rope_type[layer_type]
-            max_seq_len_cached = getattr(
-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
-            )
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"

-        if seq_len > max_seq_len_cached:  # growth
-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-            inv_freq, self.attention_scaling = rope_init_fn(
-                self.config,
-                device,
-                seq_len=seq_len,
-                layer_type=layer_type,
-            )
-            # TODO joao: may break with compilation
-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
+        # Second test to translate.
+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+        # But in that case the following condition is a way to restore the original cache.

-        if (
-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
-        ):  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
+        # if (
+        #     seq_len < self.original_max_seq_len
+        #     and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #     self.original_inv_freq = self.original_inv_freq.to(device)
+        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #     self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)

     @wraps(rope_forward)
     def wrapper(self, x, position_ids, layer_type=None):
-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
-        if "dynamic" in rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
-        elif rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
-        return rope_forward(self, x, position_ids, **kwargs)
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        return rope_forward(self, x, position_ids, layer_type=layer_type)

     return wrapper
impacted nodes
aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0) -> unsqueeze
aten.unsqueeze.default(unsqueeze, 2) -> unsqueeze_1
aten._assert_tensor_metadata.default(unsqueeze_1) -> _assert_tensor_metadata_default_3
aten.to.dtype(unsqueeze_1, torch.float32) -> to_3
aten.expand.default(to_3, [sym_size_int_19, -1, 1]) -> expand_4
aten._assert_tensor_metadata.default(expand_4) -> _assert_tensor_metadata_default_4
aten.to.dtype_layout(expand_4) -> to_4
aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807) -> slice_1
aten.unsqueeze.default(slice_1, 1) -> unsqueeze_2
aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807) -> slice_2
aten._assert_tensor_metadata.default(slice_2) -> _assert_tensor_metadata_default_5
aten.to.dtype(slice_2, torch.float32) -> to_5
<built-in function getitem>(wrap_with_autocast, 0) -> mul
<built-in function getitem>(wrap_with_autocast, 1) -> mul_1
aten._assert_tensor_metadata.default(mul) -> _assert_tensor_metadata_default_8
aten.to.dtype(mul, torch.float32) -> to_8
aten._assert_tensor_metadata.default(mul_1) -> _assert_tensor_metadata_default_9
aten.to.dtype(mul_1, torch.float32) -> to_9
- append(family: str, function_to_patch: str | Callable, patch: Callable) → PatchInfo[source]¶
Stores a patch; a usage sketch follows the parameter list.
- Parameters:
family – a category, anything to classify the patch
function_to_patch – function to patch
patch – the patched (replacement) function
- Returns:
instance of PatchInfo
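A minimal sketch of registering a pair by hand, assuming the patched function patched_eager_mask shipped with this package is used as the replacement (the same pair as in the last example of this page); the returned PatchInfo can then render the corresponding diff:

import transformers
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

details = PatchDetails()
# family is a free-form label used to group patches in the report
info = details.append(
    "manual/patch_transformers",
    transformers.masking_utils.eager_mask,  # function to patch
    patched_eager_mask,                     # its replacement
)
print(info.format_diff(format="raw"))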
- make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') → str[source]¶
Creates a report based on the involved patches; see the sketch below.
- Parameters:
patches – the pairs returned by patches_involded_in_graph()
format – format of the report
- Returns:
report
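A minimal sketch, assuming details and ep come from the example at the top of this page; the same pairs can be rendered either as plain text or as reStructuredText:

pairs = details.patches_involded_in_graph(ep.graph)
print(details.make_report(pairs, format="raw"))  # plain-text report
print(details.make_report(pairs, format="rst"))  # reStructuredText, as printed above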
- matching_pair(patch: PatchInfo, node: torch.fx.Node) → bool[source]¶
Last validation for a pair. RotaryEmbedding has many rewritings and they all end up in the same code line.
- patches_involded_in_graph(graph: torch.fx.Graph) → List[Tuple[PatchInfo, List[torch.fx.Node]]][source]¶
Enumerates all patches impacting a graph. The function goes through the graph nodes (only the main graph) and looks into their metadata to determine whether a listed patch was involved. A sketch of iterating over the result follows the parameter list.
- Parameters:
graph – fx graph
- Returns:
list of pairs (patch, nodes impacted by that patch)
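A sketch of inspecting the returned pairs directly instead of producing a report, assuming details and ep come from the example at the top of this page and that a PatchInfo keeps its constructor arguments (family, patch) as attributes:

for patch_info, nodes in details.patches_involded_in_graph(ep.graph):
    # each pair associates a registered patch with the fx nodes it produced
    print(patch_info.family, patch_info.patch, f"{len(nodes)} node(s)")
    for node in nodes:
        print("   ", node.target, "->", node.name)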
- class onnx_diagnostic.torch_export_patches.patch_details.PatchInfo(function_to_patch: str | Callable, patch: Callable, family: str = '')[source]¶
Stores information about patches.
- Parameters:
function_to_patch – function to patch
patch – the patched (replacement) function
family – a category, anything to classify the patch
- format_diff(format: str = 'raw') → str[source]¶
Formats the diff between two functions as a string.
- Parameters:
format – 'raw' or 'rst'
- Returns:
diff
<<<
import transformers
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

eager_mask = transformers.masking_utils.eager_mask
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
print(diff)
>>>
eager_mask -> patched_eager_mask¶
--- original
+++ rewritten
@@ -1,4 +1,4 @@
-def eager_mask(
+def patched_eager_mask(
     batch_size: int,
     cache_position: torch.Tensor,
     kv_length: int,
@@ -8,31 +8,12 @@
     dtype: torch.dtype = torch.float32,
     **kwargs,
 ) -> torch.Tensor:
-    """
-    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
-    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
-    it should not.
-
-    Args:
-        batch_size (`int`):
-            The batch size of the input sequence.
-        cache_position (`torch.Tensor`):
-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
-        kv_length (`int`):
-            The size that the key and value states will have during the attention computation.
-        kv_offset (`int`, optional):
-            An optional offset to indicate at which first position the key and values states will refer to.
-        mask_function (`Callable`):
-            The mask factory function describing the mask pattern.
-        attention_mask (`torch.Tensor`, optional):
-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
-        dtype (`torch.dtype`, optional):
-            The dtype to use for the mask. By default, `torch.float32`.
-    """
+    """manual patch for function ``transformers.masking_utils.eager_mask``."""
     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
     _ = kwargs.pop("allow_is_causal_skip", None)
     _ = kwargs.pop("allow_is_bidirectional_skip", None)
-    mask = sdpa_mask(
+    # PATCHED: this line called the patched version of sdpa_mask
+    mask = patched_sdpa_mask_recent_torch(
         batch_size=batch_size,
         cache_position=cache_position,
         kv_length=kv_length,
@@ -45,6 +26,10 @@
         **kwargs,
     )
     min_dtype = torch.finfo(dtype).min
-    # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
-    mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+    # PATCHED: the following line
+    # we need 0s where the tokens should be taken into account,
+    # and -inf otherwise (mask is already of boolean type)
+    # mask =
+    #    torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+    mask = (~mask).to(dtype) * min_dtype
     return mask