onnx_diagnostic.torch_export_patches.patch_details¶
- class onnx_diagnostic.torch_export_patches.patch_details.PatchDetails[source]¶
This class stores patching information. It helps in understanding which rewriting was applied to which method or function. The page Patches Diff contains the diffs for all the implemented patches.
<<<
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
details = PatchDetails()
with torch_export_patches(
    patch_transformers=True, patch_details=details, patch_torch=False
):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )
patches = details.patches_involded_in_graph(ep.graph)
report = details.make_report(patches, format="rst")
print(report)
>>>
auto/patch_transformers: LlamaRotaryEmbedding.forward -> common_RotaryEmbedding.forward¶
--- original
+++ rewritten
@@ -1,18 +1,26 @@
-@torch.no_grad()
-@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-def forward(self, x, position_ids):
+@patched_dynamic_rope_update
+def forward(self, x, position_ids, layer_type=None):
+    if layer_type is not None:
+        # transformers>=5.0
+        inv_freq = getattr(self, f"{layer_type}_inv_freq")
+        attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+    else:
+        # transformers<5.0
+        inv_freq = self.inv_freq
+        attention_scaling = self.attention_scaling
+
     inv_freq_expanded = (
-        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
     )
     position_ids_expanded = position_ids[:, None, :].float()

     device_type = (
         x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
     )
-    with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
-        cos = emb.cos() * self.attention_scaling
-        sin = emb.sin() * self.attention_scaling
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling

     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

--- original
+++ rewritten
@@ -1,99 +1,193 @@
-def dynamic_rope_update(rope_forward):
-    """
-    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
-    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+def patched_dynamic_rope_update(rope_forward):
+    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

-    Args:
-        rope_forward (Callable):
-            The forward pass of the RoPE implementation.
+    ``rope_type`` is determined in the constructor of class
+    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.

-    Returns:
-        The decorated forward pass.
+    .. code-block:: python
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+
+    The original code of the patched function:
+
+    .. code-block:: python
+
+        def dynamic_rope_update(rope_forward):
+            def longrope_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if hasattr(self.config, "original_max_position_embeddings"):
+                    original_max_position_embeddings =
+                        self.config.original_max_position_embeddings
+                else:
+                    original_max_position_embeddings =
+                        self.config.max_position_embeddings
+                if seq_len > original_max_position_embeddings:
+                    if not hasattr(self, "long_inv_freq"):
+                        self.long_inv_freq, _ = self.rope_init_fn(
+                            self.config, device, seq_len=original_max_position_embeddings + 1
+                        )
+                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                else:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+            def dynamic_frequency_update(self, position_ids, device):
+                seq_len = torch.max(position_ids) + 1
+                if seq_len > self.max_seq_len_cached:  # growth
+                    inv_freq, self.attention_scaling = self.rope_init_fn(
+                        self.config, device, seq_len=seq_len)
+                    self.register_buffer("inv_freq", inv_freq, persistent=False)
+                    self.max_seq_len_cached = seq_len
+
+                if seq_len < self.original_max_seq_len and
+                        self.max_seq_len_cached > self.original_max_seq_len:
+                    self.original_inv_freq = self.original_inv_freq.to(device)
+                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+                    self.max_seq_len_cached = self.original_max_seq_len
+
+            @wraps(rope_forward)
+            def wrapper(self, x, position_ids):
+                if "dynamic" in self.rope_type:
+                    dynamic_frequency_update(self, position_ids, device=x.device)
+                elif self.rope_type == "longrope":
+                    longrope_frequency_update(self, position_ids, device=x.device)
+                return rope_forward(self, x, position_ids)
+
+            return wrapper
+
     """

     def longrope_frequency_update(self, position_ids, device, layer_type=None):
-        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
         seq_len = torch.max(position_ids) + 1
-        original_max_position_embeddings = getattr(
-            self.config, "original_max_position_embeddings", self.config.max_position_embeddings
-        )
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+
         if layer_type is None:
-            rope_type = self.rope_type
+            # rope_type = self.rope_type
             original_inv_freq = self.original_inv_freq
             prefix = ""
         else:
-            rope_type = self.rope_type[layer_type]
+            # rope_type = self.rope_type[layer_type]
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"

-        if seq_len > original_max_position_embeddings:
-            if not hasattr(self, f"{layer_type}_long_inv_freq"):
-                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-                long_inv_freq, _ = rope_init_fn(
-                    self.config,
-                    device,
-                    seq_len=original_max_position_embeddings + 1,
-                    layer_type=layer_type,
-                )
-                self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
-                setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
-        else:
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)
+
+        # PATCHED: uses torch.cond instead of a test
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)
+        # if seq_len > original_max_position_embeddings:
+        #     self.inv_freq = self.long_inv_freq
+        # else:
+        #     self.inv_freq = self.original_inv_freq

     def dynamic_frequency_update(self, position_ids, device, layer_type=None):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
+        # constructor:
+        # - self.max_seq_len_cached = config.max_position_embeddings
+        # - self.original_max_seq_len = config.max_position_embeddings
+        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+
+        # It is no use to patch the function after the model is created
+        # as rope_init_fn is an attribute set to one function when the model
+        # is created and when no patch is applied yet.
+        # So we select the patched version here.
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
+
+        # This behaviour is difficult to translate.
+        # The sequence always grows.
+        # The test should always True.
+        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
+        #
+        # if seq_len > self.max_seq_len_cached:  # growth
+        #     inv_freq, self.attention_scaling = self.rope_init_fn(
+        #         self.config, device, seq_len=seq_len
+        #     )
+        #     self.register_buffer("inv_freq", inv_freq, persistent=False)
+        #     self.max_seq_len_cached = seq_len
+        #
+        # So we should not need what follows.
+        #
+        # cond = (seq_len > self.max_seq_len_cached).item()
+        # self.attention_scaling = torch.cond(
+        #     cond,
+        #     (lambda x, y: x.clone()),
+        #     (lambda x, y: y.clone()),
+        #     [attention_scaling, self.attention_scaling],
+        # )
+
         seq_len = torch.max(position_ids) + 1
+        long_inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
+
         if layer_type is None:
-            rope_type = self.rope_type
-            max_seq_len_cached = self.max_seq_len_cached
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
             original_inv_freq = self.original_inv_freq
             prefix = ""
         else:
-            rope_type = self.rope_type[layer_type]
-            max_seq_len_cached = getattr(
-                self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
-            )
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
             original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
             prefix = f"{layer_type}_"

-        if seq_len > max_seq_len_cached:  # growth
-            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
-            inv_freq, self.attention_scaling = rope_init_fn(
-                self.config,
-                device,
-                seq_len=seq_len,
-                layer_type=layer_type,
-            )
-            # TODO joao: may break with compilation
-            self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
-            setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
+        # Second test to translate.
+        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
+        # But in that case the following condition is a way to restore the original cache.

-        if (
-            seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len
-        ):  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            original_inv_freq = original_inv_freq.to(device)
-            self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
-            setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
-            setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
+        # if (
+        #     seq_len < self.original_max_seq_len
+        #     and self.max_seq_len_cached > self.original_max_seq_len
+        # ):
+        #     self.original_inv_freq = self.original_inv_freq.to(device)
+        #     self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+        #     self.max_seq_len_cached = self.original_max_seq_len
+
+        original_inv_freq = self.original_inv_freq.to(device)
+        cond = (seq_len >= self.original_max_seq_len).item()
+        # PATCHED: uses torch.cond instead of a test
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        setattr(self, f"{prefix}inv_freq", inv_freq)

     @wraps(rope_forward)
     def wrapper(self, x, position_ids, layer_type=None):
-        rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
-        kwargs = {"layer_type": layer_type} if layer_type is not None else {}
-        if "dynamic" in rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
-        elif rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
-        return rope_forward(self, x, position_ids, **kwargs)
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        return rope_forward(self, x, position_ids, layer_type=layer_type)

     return wrapper
impacted nodes
aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0) -> unsqueeze_12
aten.unsqueeze.default(unsqueeze_12, 2) -> unsqueeze_13
aten._assert_tensor_metadata.default(unsqueeze_13) -> _assert_tensor_metadata_default_3
aten.to.dtype(unsqueeze_13, torch.float32) -> to_3
aten.expand.default(to_3, [sym_size_int_19, -1, 1]) -> expand_1
aten._assert_tensor_metadata.default(expand_1) -> _assert_tensor_metadata_default_4
aten.to.dtype_layout(expand_1) -> to_4
aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807) -> slice_4
aten.unsqueeze.default(slice_4, 1) -> unsqueeze_14
aten.slice.Tensor(unsqueeze_14, 2, 0, 9223372036854775807) -> slice_5
aten._assert_tensor_metadata.default(slice_5) -> _assert_tensor_metadata_default_5
aten.to.dtype(slice_5, torch.float32) -> to_5
<built-in function getitem>(wrap_with_autocast, 0) -> mul
<built-in function getitem>(wrap_with_autocast, 1) -> mul_1
aten._assert_tensor_metadata.default(mul) -> _assert_tensor_metadata_default_8
aten.to.dtype(mul, torch.float32) -> to_8
aten._assert_tensor_metadata.default(mul_1) -> _assert_tensor_metadata_default_9
aten.to.dtype(mul_1, torch.float32) -> to_9
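The impacted nodes listed above are regular torch.fx nodes taken from the exported graph. As a minimal sketch (assuming the ep exported program from the example above and only the standard torch.fx API, nothing specific to PatchDetails), a similar listing can be produced directly:

# Hedged sketch: enumerate the call_function nodes of the exported graph
# from the example above, printing target, arguments and node name.
for node in ep.graph.nodes:
    if node.op == "call_function":
        print(f"{node.target}({', '.join(map(str, node.args))}) -> {node.name}")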
- append(family: str, function_to_patch: str | Callable, patch: Callable) → PatchInfo[source]¶
Stores a patch.
- Parameters:
family – a category, anything to classify the patch
function_to_patch – function to patch
patch – the patched (replacement) function
- Returns:
instance of PatchInfo
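A minimal sketch of registering a patch by hand (the two functions below are hypothetical placeholders, not part of the library; when patch_details is passed to torch_export_patches, as in the first example, the patches are recorded for you):

from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails

def original_fn(x):  # hypothetical function to be patched
    return x + 1

def patched_fn(x):  # hypothetical replacement
    return x * 2

details = PatchDetails()
info = details.append("custom", original_fn, patched_fn)  # returns a PatchInfo
print(info.format_diff(format="raw"))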
- make_report(patches: List[Tuple[PatchInfo, List[torch.fx.Node]]], format: str = 'raw') → str[source]¶
Creates a report based on the involved patches.
- Parameters:
patches – output of method patches_involded_in_graph()
format – format of the report ('raw' or 'rst')
- Returns:
report
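A minimal sketch, assuming the details and ep objects from the first example on this page:

patches = details.patches_involded_in_graph(ep.graph)
raw_report = details.make_report(patches, format="raw")  # plain-text report (default)
rst_report = details.make_report(patches, format="rst")  # reStructuredText report
print(raw_report)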
- matching_pair(patch: PatchInfo, node: torch.fx.Node) → bool[source]¶
Last validation for a pair. RotaryEmbedding has many rewritings, and they all end up on the same line of code.
- patches_involded_in_graph(graph: torch.fx.Graph) → List[Tuple[PatchInfo, List[torch.fx.Node]]][source]¶
Enumerates all patches impacting a graph. The function goes through the graph nodes (only the main graph) and looks into their metadata to determine whether a listed patch was involved.
- Parameters:
graph – fx graph
- Returns:
list of pairs (patch, nodes impacted by this patch)
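A minimal sketch of how the returned pairs can be inspected, assuming the details and ep objects from the first example (reading family below assumes the constructor parameter of PatchInfo is kept as an attribute):

for patch_info, nodes in details.patches_involded_in_graph(ep.graph):
    # one entry per patch that left a trace in the main graph
    print(patch_info.family, [n.name for n in nodes])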
- class onnx_diagnostic.torch_export_patches.patch_details.PatchInfo(function_to_patch: str | Callable, patch: Callable, family: str = '')[source]¶
Stores information about patches.
- Parameters:
function_to_patch – function to patch
patch – the patched (replacement) function
family – a category, anything to classify the patch
- format_diff(format: str = 'raw') → str[source]¶
Formats the diff between the two functions as a string.
- Parameters:
format – 'raw' or 'rst'
- Returns:
diff
<<<
import transformers
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    patched_eager_mask,
)

eager_mask = transformers.masking_utils.eager_mask
diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
print(diff)
>>>
eager_mask -> patched_eager_mask¶
--- original
+++ rewritten
@@ -1,4 +1,4 @@
-def eager_mask(
+def patched_eager_mask(
     batch_size: int,
     cache_position: torch.Tensor,
     kv_length: int,
@@ -6,41 +6,14 @@
     mask_function: Callable = causal_mask_function,
     attention_mask: Optional[torch.Tensor] = None,
     dtype: torch.dtype = torch.float32,
-    allow_is_bidirectional_skip: bool = False,
-    use_vmap: bool = False,
     **kwargs,
 ) -> torch.Tensor:
-    """
-    Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
-    the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
-    it should not.
-
-    Args:
-        batch_size (`int`):
-            The batch size of the input sequence.
-        cache_position (`torch.Tensor`):
-            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
-        kv_length (`int`):
-            The size that the key and value states will have during the attention computation.
-        kv_offset (`int`, optional):
-            An optional offset to indicate at which first position the key and values states will refer to.
-        mask_function (`Callable`):
-            The mask factory function describing the mask pattern.
-        attention_mask (`torch.Tensor`, optional):
-            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
-        dtype (`torch.dtype`, optional):
-            The dtype to use for the mask. By default, `torch.float32`.
-        allow_is_bidirectional_skip (`bool`, optional):
-            Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
-            i.e. full attention without any padding. Default to `False`.
-        use_vmap (`bool`, optional):
-            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
-            index-based (for the cost of speed performance). By default `False`.
-    """
+    """manual patch for function ``transformers.masking_utils.eager_mask``."""
     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
     _ = kwargs.pop("allow_is_causal_skip", None)
-    _ = kwargs.pop("allow_torch_fix", None)
-    mask = sdpa_mask(
+    _ = kwargs.pop("allow_is_bidirectional_skip", None)
+    # PATCHED: this line called the patched version of sdpa_mask
+    mask = patched_sdpa_mask_recent_torch(
         batch_size=batch_size,
         cache_position=cache_position,
         kv_length=kv_length,
@@ -48,14 +21,15 @@
         mask_function=mask_function,
         attention_mask=attention_mask,
         allow_is_causal_skip=False,
-        allow_is_bidirectional_skip=allow_is_bidirectional_skip,
+        allow_is_bidirectional_skip=False,
         allow_torch_fix=False,
-        use_vmap=use_vmap,
         **kwargs,
     )
-    # only bidirectional masks can be skipped, otherwise we convert bool -> float
-    if mask is not None:
-        min_dtype = torch.finfo(dtype).min
-        # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
-        mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+    min_dtype = torch.finfo(dtype).min
+    # PATCHED: the following line
+    # we need 0s where the tokens should be taken into account,
+    # and -inf otherwise (mask is already of boolean type)
+    # mask =
+    #     torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
+    mask = (~mask).to(dtype) * min_dtype
     return mask
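The same diff can also be rendered as plain text; a minimal variant of the example above, assuming the same imports (the family label is an arbitrary illustrative string, not one used by the library):

diff_raw = PatchInfo(
    eager_mask, patched_eager_mask, family="manual/patch_transformers"
).format_diff(format="raw")
print(diff_raw)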