onnx_diagnostic.torch_export_patches
submodules
- onnx_diagnostic.torch_export_patches.bypass_export_some_errors(patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, catch_constraints: bool = True, stop_if_static: bool = False, verbose: int = 0, patch: bool = True) -> Callable
Tries to bypass some situations that torch.export.export() does not support.
Parameters:
patch_sympy – fixes the missing method name for IntegerConstant
patch_torch – patches torch with supported implementations
patch_transformers – patches transformers with supported implementations
catch_constraints – catches constraints related to dynamic shapes; as a result, some dynamic dimensions may turn into static ones. The environment variable SKIP_SOLVE_CONSTRAINTS=0 can be set to stop at that stage.
stop_if_static – stops the export as soon as an issue with dynamic shapes is detected and shows a stack trace indicating the exact location of the issue; see the example "Find and fix an export issue due to dynamic shapes"
patch – if False, disables all patches except the registration of serialization functions
verbose – shows which patches are applied
The list of available patches:
- torch.jit.isinstance
- torch._dynamo.mark_static_address
- torch._subclasses.fake_impls.infer_size
- fixes the missing method name for sympy.S.IntegerConstant
- AttentionMaskConverter._make_causal_mask
- serialization of MambaCache (in transformers)
- serialization of DynamicCache (in transformers)
- reduces errors due to shape inference
- fixes some transformers classes, see onnx_diagnostic.torch_export_patches.patches.patch_transformers
Serialization issues happen when a module takes an input, or returns an output, whose type torch.export.export() cannot serialize.
Examples:
    with bypass_export_some_errors(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        onx = to_onnx(..., inputs, ...)

    with bypass_export_some_errors(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        onx = torch.onnx.export(..., inputs, ...)
It can also be used to fix torch.export.export():
    with bypass_export_some_errors(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        ep = torch.export.export(..., inputs, ...)
When running the model through the exported program, only the serialization functions need to be restored:
    with register_additional_serialization_functions() as modificator:
        inputs = modificator(inputs)
        ep = torch.export.export(..., inputs, ...)
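Registering serialization functions amounts to telling the exporter how to flatten an object of an otherwise unsupported type into leaves plus a static description, and how to rebuild it. In PyTorch this is done through pytree node registration (for example torch.utils._pytree.register_pytree_node). The stdlib-only sketch below mimics that idea under stated assumptions. register_node, flatten, unflatten and ToyCache are hypothetical names, not part of onnx_diagnostic or transformers. ToyCache loosely stands in for a cache such as DynamicCache, and strings stand in for tensors.

```python
from typing import Any, Callable, Dict, List, Tuple

FlattenFn = Callable[[Any], Tuple[List[Any], Any]]
UnflattenFn = Callable[[List[Any], Any], Any]
# Hypothetical registry: type -> (flatten function, unflatten function).
_REGISTRY: Dict[type, Tuple[FlattenFn, UnflattenFn]] = {}


def register_node(cls: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn) -> None:
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


def flatten(obj: Any) -> Tuple[List[Any], Any]:
    """Returns the leaves of ``obj`` and a spec describing how to rebuild it."""
    entry = _REGISTRY.get(type(obj))
    if entry is None:
        return [obj], None  # unregistered type: treated as a leaf
    children, context = entry[0](obj)
    leaves: List[Any] = []
    specs: List[Tuple[int, Any]] = []
    for child in children:
        child_leaves, child_spec = flatten(child)
        leaves.extend(child_leaves)
        specs.append((len(child_leaves), child_spec))
    return leaves, (type(obj), context, specs)


def unflatten(leaves: List[Any], spec: Any) -> Any:
    if spec is None:
        return leaves[0]
    cls, context, specs = spec
    children, pos = [], 0
    for count, child_spec in specs:
        children.append(unflatten(leaves[pos:pos + count], child_spec))
        pos += count
    return _REGISTRY[cls][1](children, context)


class ToyCache:
    """Hypothetical cache holding one key and one value entry per layer."""

    def __init__(self, key_cache=None, value_cache=None):
        self.key_cache = key_cache or []
        self.value_cache = value_cache or []


register_node(list, lambda xs: (list(xs), None), lambda children, _: list(children))
register_node(
    ToyCache,
    lambda c: ([c.key_cache, c.value_cache], None),
    lambda children, _: ToyCache(children[0], children[1]),
)

cache = ToyCache(["k0", "k1"], ["v0", "v1"])
leaves, spec = flatten(cache)
rebuilt = unflatten(leaves, spec)
```

Once the flatten/unflatten pair is registered, a ToyCache input can be handled like any nested structure of leaves. The real serialization functions target the transformers classes listed above; their internals are not reproduced here.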
When exporting a model with a cache, the following error message may appear:
    AssertionError: Mutating module attribute _seen_tokens during export.
It can be avoided by setting strict=False when calling torch.export.export().