onnx_diagnostic.torch_export_patches
submodules
- onnx_diagnostic.torch_export_patches.eval
- onnx_diagnostic.torch_export_patches.onnx_export_errors
- onnx_diagnostic.torch_export_patches.onnx_export_serialization
- onnx_diagnostic.torch_export_patches.patches
- onnx_diagnostic.torch_export_patches.patch_expressions
- onnx_diagnostic.torch_export_patches.patch_inputs
- onnx_diagnostic.torch_export_patches.patch_module
- onnx_diagnostic.torch_export_patches.patch_module_helper
- onnx_diagnostic.torch_export_patches.serialization
- onnx_diagnostic.torch_export_patches.register_flattening_functions(verbose: int = 0)
- Registers functions to serialize and deserialize the cache classes, or other classes implemented in transformers, that are used as inputs. This is needed whenever a model must be exported through torch.export.export().
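The registration mechanism can be illustrated without torch. The following is a toy, pure-Python sketch of a pytree-style registry, loosely modeled on what torch.utils._pytree.register_pytree_node provides: a cache-like object is split into a flat list of leaves plus a context and rebuilt from them, and an unregistered type is rejected, much as an exporter rejects a type it cannot serialize. ToyCache, register_flattening, flatten and unflatten are hypothetical names used only for illustration. They are not part of onnx_diagnostic, and the functions actually registered by register_flattening_functions may differ.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: type -> (flatten_fn, unflatten_fn).
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def register_flattening(cls: type, flatten_fn: Callable, unflatten_fn: Callable) -> None:
    """Registers how to split an instance of cls into leaves and a context."""
    if cls in _REGISTRY:
        raise ValueError(f"{cls.__name__} is already registered")
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


def flatten(obj: Any) -> Tuple[List[Any], Tuple[type, Any]]:
    """Returns (leaves, spec); unregistered types cannot be serialized."""
    if type(obj) not in _REGISTRY:
        raise TypeError(f"cannot serialize {type(obj).__name__}")
    leaves, context = _REGISTRY[type(obj)][0](obj)
    return leaves, (type(obj), context)


def unflatten(leaves: List[Any], spec: Tuple[type, Any]) -> Any:
    """Rebuilds the original object from its leaves and spec."""
    cls, context = spec
    return _REGISTRY[cls][1](leaves, context)


@dataclass
class ToyCache:
    """Hypothetical stand-in for a transformers cache: one key and one value per layer."""

    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


def _flatten_cache(cache: ToyCache):
    # Leaves are the per-layer tensors; the context keeps the number of layers.
    return [*cache.key_cache, *cache.value_cache], len(cache.key_cache)


def _unflatten_cache(leaves: List[Any], n_layers: int) -> ToyCache:
    return ToyCache(list(leaves[:n_layers]), list(leaves[n_layers:]))


register_flattening(ToyCache, _flatten_cache, _unflatten_cache)
```

Once a type is registered, a tracer only needs to handle the leaves (here, lists standing in for per-layer tensors) and can rebuild the object afterwards. This round trip is, in spirit, what the registration enables for the transformers cache classes.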
- onnx_diagnostic.torch_export_patches.torch_export_rewrite(rewrite: torch.nn.Module | List[Tuple[type, str] | Callable] | None = None, dump_rewriting: str | None = None, verbose: int = 0)
- Automatically rewrites the methods given in rewrite so that their control flow (tests and loops) can be exported.
- Parameters:
- rewrite – methods or functions to rewrite. A method is defined by its class (a type) and its name if the class is local, or by itself otherwise. It can also be a model; in that case, the function calls code_needing_rewriting to discover the methods that need rewriting.
- dump_rewriting – dumps the rewriting into that folder, creating the folder if it does not exist.
- verbose – verbosity, up to 10 (10 shows the rewritten code); verbose=1 shows the rewritten function, verbose=2 shows the rewritten code as well.
 
 - Example:

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_rewrite

    class Model(torch.nn.Module):
        def forward(self, x, y):
            if x.sum() > 0:
                return x + y
            else:
                return torch.abs(x) + y + 1

    model = Model()
    x, y = torch.rand((4, 5)), torch.rand((4, 5))
    DYN = torch.export.Dim.DYNAMIC
    ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
    with torch_export_rewrite(rewrite=[(Model, "forward")]):
        ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

- If the method to rewrite is not local, the following can be used:

    with torch_export_rewrite(rewrite=[Model.forward]):
        ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

- Functions (if not local) can also be rewritten:

    def outside(x, y):
        if x.sum() > 0:
            return x + y
        else:
            return torch.abs(x) + y + 1

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return outside(x, y)

    model = Model()
    x, y = torch.rand((4, 5)), torch.rand((4, 5))
    DYN = torch.export.Dim.DYNAMIC
    ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
    with torch_export_rewrite(rewrite=[outside]):
        ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
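To give an idea of what such a rewriting produces, the sketch below shows, in plain Python, the general shape of a rewritten forward: the if/else on x.sum() > 0 becomes two branch functions passed to a conditional operator. The cond function here is a hypothetical eager stand-in for an operator such as torch.cond, and lists of floats replace tensors so the sketch runs without torch. The code actually generated by torch_export_rewrite may differ; verbose=2 or dump_rewriting shows it.

```python
from typing import Callable, List, Sequence

Vec = List[float]


def cond(pred: bool, true_fn: Callable, false_fn: Callable, operands: Sequence) -> Vec:
    """Hypothetical eager stand-in for a conditional operator such as torch.cond.

    Both branches receive the same operands; an exporter can then trace each
    branch as a function instead of following one Python path.
    """
    return true_fn(*operands) if pred else false_fn(*operands)


def forward_original(x: Vec, y: Vec) -> Vec:
    # Data-dependent Python control flow: the path depends on the values of x.
    if sum(x) > 0:
        return [a + b for a, b in zip(x, y)]
    else:
        return [abs(a) + b + 1 for a, b in zip(x, y)]


def forward_rewritten(x: Vec, y: Vec) -> Vec:
    # Each branch becomes a function of the variables it uses.
    def branch_true(x: Vec, y: Vec) -> Vec:
        return [a + b for a, b in zip(x, y)]

    def branch_false(x: Vec, y: Vec) -> Vec:
        return [abs(a) + b + 1 for a, b in zip(x, y)]

    return cond(sum(x) > 0, branch_true, branch_false, [x, y])


print(forward_rewritten([1.0, 2.0], [0.5, 0.5]))   # [1.5, 2.5]
print(forward_rewritten([-3.0, 1.0], [0.5, 0.5]))  # [4.5, 2.5]
```

Both versions return the same values. The difference is that the branch choice is now an explicit operator call rather than Python control flow, which lets an exporter keep both paths instead of specializing on the example input.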
- onnx_diagnostic.torch_export_patches.torch_export_patches(patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, patch_diffusers: bool = False, catch_constraints: bool = True, stop_if_static: int = 0, verbose: int = 0, patch: bool = True, custom_patches: List[type[torch.nn.Module]] | None = None, rewrite: List[Callable] | None = None, dump_rewriting: str | None = None) -> Callable
- Tries to bypass some situations that torch.export.export() does not support. See also Patches Explained and Coverage of the Patches.
- Parameters:
- patch_sympy – fixes the missing method name for IntegerConstant
- patch_torch – patches torch with supported implementations
- patch_transformers – patches transformers with supported implementations
- patch_diffusers – patches diffusers with supported implementations
- catch_constraints – catches constraints related to dynamic shapes; as a result, some dynamic dimensions may turn into static ones. Set the environment variable SKIP_SOLVE_CONSTRAINTS=0 to stop at that stage.
- stop_if_static – stops the export as soon as an issue with dynamic shapes is detected and shows a stack trace indicating the exact location of the issue (see the example Find and fix an export issue due to dynamic shapes); if stop_if_static > 1, more methods are replaced to catch more issues
- patch – if False, disables all patches but keeps the registration of serialization functions if other patch functions are enabled
- custom_patches – custom patches to apply; every patched class must define the static attributes _PATCHES_ and _PATCHED_CLASS_
- rewrite – list of methods to rewrite automatically before exporting. Methods with control flow must be rewritten before being exported if the execution path depends on the inputs; this is done by the function transform_method, whose documentation lists the possible values
- dump_rewriting – dumps rewriting information into files whose names begin with that prefix
- verbose – shows which patches are applied
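The custom_patches convention (a patch class carrying _PATCHES_ and _PATCHED_CLASS_) can be sketched as follows. This is an assumption about how such patches are applied: _PATCHES_ is taken to be the list of method names replaced on _PATCHED_CLASS_, and apply_custom_patches is a hypothetical context manager that swaps the methods in on entry and restores the originals on exit. Layer and patched_Layer are illustrative names; none of this is the library's implementation.

```python
import contextlib
from typing import Iterator, List


class Layer:
    """Hypothetical stand-in for a third-party class whose method is hard to export."""

    def activation(self, values: List[float]) -> List[float]:
        # Python branch on each value: data-dependent control flow.
        out = []
        for v in values:
            if v > 0:
                out.append(v)
            else:
                out.append(0.0)
        return out


class patched_Layer:
    # The two static attributes required by custom_patches.
    _PATCHED_CLASS_ = Layer
    _PATCHES_ = ["activation"]

    def activation(self, values: List[float]) -> List[float]:
        # Same result without a Python branch on the values.
        return [max(v, 0.0) for v in values]


@contextlib.contextmanager
def apply_custom_patches(patches: List[type]) -> Iterator[None]:
    """Hypothetical helper: swap the methods in on enter, restore them on exit.

    Assumes every patched method is defined directly on _PATCHED_CLASS_.
    """
    saved = []
    try:
        for patch in patches:
            target = patch._PATCHED_CLASS_
            for name in patch._PATCHES_:
                saved.append((target, name, target.__dict__[name]))
                setattr(target, name, patch.__dict__[name])
        yield
    finally:
        # Undo in reverse order so that stacked patches unwind correctly.
        for target, name, original in reversed(saved):
            setattr(target, name, original)
```

Inside `with apply_custom_patches([patched_Layer]):`, every Layer instance uses the patched method; the original is restored afterwards, even if the code inside the block raises.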
 
 - The list of available patches:
- torch.jit.isinstance
- torch._dynamo.mark_static_address
- torch._subclasses.fake_impls.infer_size
- torch.vmap
- fix missing method name for sympy.S.IntegerConstant
- AttentionMaskConverter._make_causal_mask
- Serialization of MambaCache (in transformers)
- Serialization of DynamicCache (in transformers)
- reduce errors due to shape inference
- fixes some transformers classes, see onnx_diagnostic.torch_export_patches.patches.patch_transformers
 - Serialization issues happen when one input or output of a module has a type that torch.export.export() cannot serialize.
- Examples:

    with torch_export_patches(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        onx = to_onnx(..., inputs, ...)

    with torch_export_patches(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        onx = torch.onnx.export(..., inputs, ...)

- It can also be used to fix the torch export:

    with torch_export_patches(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        ep = torch.export.export(..., inputs, ...)

- When running the model through the exported program, only the serialization functions need to be restored:

    with register_additional_serialization_functions() as modificator:
        inputs = modificator(inputs)
        ep = torch.export.export(..., inputs, ...)

- When exporting a model with a cache, the following error message may appear: AssertionError: Mutating module attribute _seen_tokens during export. It can be avoided by setting strict=False when calling torch.export.export().