onnx_diagnostic.torch_export_patches¶
submodules
- onnx_diagnostic.torch_export_patches.eval
- onnx_diagnostic.torch_export_patches.onnx_export_errors
- onnx_diagnostic.torch_export_patches.onnx_export_serialization
- onnx_diagnostic.torch_export_patches.patches
- onnx_diagnostic.torch_export_patches.patch_expressions
- onnx_diagnostic.torch_export_patches.patch_inputs
- onnx_diagnostic.torch_export_patches.patch_module
- onnx_diagnostic.torch_export_patches.patch_module_helper
- onnx_diagnostic.torch_export_patches.register_flattening_functions(verbose: int = 0)[source]¶
Registers functions to serialize and deserialize caches or other classes implemented in transformers and used as inputs. This is needed whenever such a model must be exported through torch.export.export().
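To illustrate what registering flattening functions amounts to, here is a minimal pure-Python sketch of the flatten/unflatten contract torch.export relies on. ToyCache and both helper functions are hypothetical stand-ins, not part of onnx_diagnostic or transformers; real code would register such a pair with torch.utils._pytree.register_pytree_node.

```python
# Hypothetical sketch: the flatten/unflatten contract behind serialization.
# torch.export needs every input type to be decomposable into children
# (tensors in practice) and rebuilt afterwards. ToyCache is a stand-in
# for a cache class; nothing here comes from onnx_diagnostic itself.

class ToyCache:
    """Stand-in for a transformers cache holding key/value lists."""
    def __init__(self, keys, values):
        self.keys = list(keys)
        self.values = list(values)

def flatten_toy_cache(cache):
    # Return the children plus a context needed to rebuild the object.
    return [cache.keys, cache.values], "ToyCache"

def unflatten_toy_cache(children, context):
    keys, values = children
    return ToyCache(keys, values)

# Round-trip: flattening then unflattening recovers an equivalent cache.
cache = ToyCache([1, 2], [3, 4])
children, ctx = flatten_toy_cache(cache)
rebuilt = unflatten_toy_cache(children, ctx)
```

register_flattening_functions performs this kind of registration for the actual transformers cache classes so that torch.export.export can traverse them.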
- onnx_diagnostic.torch_export_patches.torch_export_rewrite(rewrite: torch.nn.Module | List[Tuple[type, str] | Callable] | None = None, dump_rewriting: str | None = None, verbose: int = 0)[source]¶
Automatically rewrites the methods given in rewrite so that their control flow (tests and loops) can be exported.
- Parameters:
rewrite – methods or functions to rewrite; if not specified, the function may try to discover them. A method is identified by its class (a type) and its name if the class is local, otherwise by itself. It can also be a model; in that case, the function calls code_needing_rewriting to retrieve the necessary rewriting.
dump_rewriting – dumps the rewriting into that folder; if the folder does not exist, it is created.
verbose – verbosity, up to 10; verbose=1 shows the rewritten function, verbose=2 shows the rewritten code as well, verbose=10 shows the fully rewritten code.
Example:
class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            return x + y
        else:
            return torch.abs(x) + y + 1

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})

with torch_export_rewrite(rewrite=[(Model, "forward")]):
    ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
If the method to rewrite is not local, then the following can be used:
with torch_export_rewrite(rewrite=[Model.forward]):
    ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
Functions (if not local) can also be rewritten:
def outside(x, y):
    if x.sum() > 0:
        return x + y
    else:
        return torch.abs(x) + y + 1

class Model(torch.nn.Module):
    def forward(self, x, y):
        return outside(x, y)

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})

with torch_export_rewrite(rewrite=[outside]):
    ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
- onnx_diagnostic.torch_export_patches.torch_export_patches(patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, catch_constraints: bool = True, stop_if_static: int = 0, verbose: int = 0, patch: bool = True, custom_patches: List[type[torch.nn.Module]] | None = None, rewrite: List[Callable] | None = None, dump_rewriting: str | None = None) → Callable[source]¶
Tries to bypass some situations torch.export.export() does not support. See also Patches Explained and Coverage of the Patches.
- Parameters:
patch_sympy – fix missing method name for sympy.S.IntegerConstant
patch_torch – patches torch with supported implementations
patch_transformers – patches transformers with supported implementations
catch_constraints – catch constraints related to dynamic shapes; as a result, some dynamic dimensions may turn into static ones; the environment variable SKIP_SOLVE_CONSTRAINTS=0 can be set to stop at that stage.
stop_if_static – see example Find and fix an export issue due to dynamic shapes; stops the export as soon as an issue is detected with dynamic shapes and shows a stack trace indicating the exact location of the issue; if stop_if_static > 1, more methods are replaced to catch more issues
patch – if False, disables all patches except the registration of serialization functions
custom_patches – to apply custom patches, every patched class must define the static attributes _PATCHES_, _PATCHED_CLASS_
rewrite – list of methods to automatically rewrite before exporting; methods with control flow need to be rewritten before being exported if the execution path depends on the inputs; this is done by function transform_method, whose documentation provides possible values
dump_rewriting – dumps rewriting information into files beginning with that prefix
verbose – to show which patches are applied
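As a hint for custom_patches, the sketch below shows the expected shape of a patch class. Only the attribute names _PATCHES_ and _PATCHED_CLASS_ come from the documentation above; SomeModule, the method bodies, and the manual swap are hypothetical illustrations of what the context manager does.

```python
# Hypothetical sketch of a class suitable for custom_patches.
# Only _PATCHES_ and _PATCHED_CLASS_ are documented attribute names;
# SomeModule and the method bodies are illustrative.

class SomeModule:
    def forward(self, x):
        return x

class patched_SomeModule:
    _PATCHED_CLASS_ = SomeModule    # class whose methods get replaced
    _PATCHES_ = ["forward"]         # names of the methods to replace

    def forward(self, x):
        return x + 1                # patched behavior

# torch_export_patches(custom_patches=[patched_SomeModule]) would perform
# a swap like this inside the context manager, then undo it on exit:
original = SomeModule.forward
SomeModule.forward = patched_SomeModule.forward
patched_result = SomeModule().forward(1)    # runs the patched method
SomeModule.forward = original
restored_result = SomeModule().forward(1)   # original behavior is back
```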
The list of available patches:
- torch.jit.isinstance
- torch._dynamo.mark_static_address
- torch._subclasses.fake_impls.infer_size
- torch.vmap
- fix missing method name for sympy.S.IntegerConstant
- AttentionMaskConverter._make_causal_mask
- serialization of MambaCache (in transformers)
- serialization of DynamicCache (in transformers)
- reduce errors due to shape inference
- fixes some transformers classes, see onnx_diagnostic.torch_export_patches.patches.patch_transformers
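The patches above are applied temporarily and undone when the context manager exits. A minimal sketch of that mechanism, with hypothetical names (patch_attribute and Target are not part of onnx_diagnostic), could look like this:

```python
# Hypothetical sketch of the temporary-patching mechanism: replace an
# attribute on entry and restore it on exit, even if an error occurs.
from contextlib import contextmanager

@contextmanager
def patch_attribute(owner, name, replacement):
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)   # always restore the original

class Target:
    @staticmethod
    def f():
        return "original"

with patch_attribute(Target, "f", staticmethod(lambda: "patched")):
    inside = Target.f()    # patched behavior inside the context
after = Target.f()         # original behavior restored outside
```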
Serialization issues happen when a module takes an input or returns an output whose type torch.export.export() cannot serialize.
Examples:
with torch_export_patches(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    onx = to_onnx(..., inputs, ...)
with torch_export_patches(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    onx = torch.onnx.export(..., inputs, ...)
It can be used as well to fix the torch export:
with torch_export_patches(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)
When running the model through the exported program, only the serialization functions need to be restored:
with register_additional_serialization_functions() as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)
When exporting a model with a cache, the following error message may appear: AssertionError: Mutating module attribute _seen_tokens during export. It can be avoided by setting strict=False when calling torch.export.export().