onnx_diagnostic.torch_export_patches

onnx_diagnostic.torch_export_patches.bypass_export_some_errors(patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, catch_constraints: bool = True, stop_if_static: bool = False, verbose: int = 0, patch: bool = True) -> Callable

Tries to bypass some situations that torch.export.export() does not support.

Parameters:
  • patch_sympy – fixes the missing method name for IntegerConstant

  • patch_torch – patches torch with supported implementations

  • patch_transformers – patches transformers with supported implementations

  • catch_constraints – catches constraints related to dynamic shapes; as a result, some dynamic dimensions may turn into static ones. Set the environment variable SKIP_SOLVE_CONSTRAINTS=0 to stop at that stage.

  • stop_if_static – stops the export as soon as an issue with dynamic shapes is detected and shows a stack trace indicating the exact location of the issue; see the example "Find and fix an export issue due to dynamic shapes"

  • patch – if False, disables all patches except the registration of the serialization functions

  • verbose – shows which patches are applied
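The patching parameters above suggest a common pattern: a context manager replaces attributes on entry and puts the originals back on exit, even if an exception is raised. The following is a minimal sketch of that pattern, assuming patches are undone when the `with` block ends; it is not the library's implementation. `apply_patches`, `FakeIntegerConstant` and its `name` method are hypothetical stand-ins, and only the `patch` and `verbose` flags mirror the documented parameters.

```python
import contextlib

# Sentinel marking "attribute did not exist before patching".
_MISSING = object()


class FakeIntegerConstant:
    """Hypothetical stand-in for a class lacking a method an exporter expects."""


def _name(self):
    # Hypothetical method added by the patch.
    return type(self).__name__


@contextlib.contextmanager
def apply_patches(patch: bool = True, verbose: int = 0):
    """Sketch: set attributes on entry, restore the originals on exit."""
    # Each entry is (owner, attribute name, replacement).
    patches = [(FakeIntegerConstant, "name", _name)] if patch else []
    saved = []
    for owner, attr, new in patches:
        # Remember the previous value (or its absence) to restore it exactly.
        saved.append((owner, attr, owner.__dict__.get(attr, _MISSING)))
        if verbose:
            print(f"[apply_patches] patch {owner.__name__}.{attr}")
        setattr(owner, attr, new)
    try:
        yield
    finally:
        # Undo in reverse order, even if the body raised an exception.
        for owner, attr, old in reversed(saved):
            if old is _MISSING:
                delattr(owner, attr)
            else:
                setattr(owner, attr, old)
            if verbose:
                print(f"[apply_patches] restore {owner.__name__}.{attr}")


with apply_patches(verbose=1):
    print(FakeIntegerConstant().name())
print(hasattr(FakeIntegerConstant, "name"))
```

Restoring inside `finally` is the design choice that matters here: the original behavior comes back even when the export inside the `with` block fails.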

See the list of available patches.

Serialization issues happen when one input or output of a module has a type that torch.export.export() cannot serialize.
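torch.export.export() flattens nested inputs and outputs into a list of leaves through a registry of flatten/unflatten functions (torch's pytree mechanism), so a container type without registered functions cannot be handled. The pure-Python sketch below imitates that mechanism to show where such an error comes from; the registry, `register`, `flatten` and `MyCache` are hypothetical, and plain numbers stand in for tensors.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: container type -> (flatten_fn, unflatten_fn).
# flatten_fn(obj) -> (children, context); unflatten_fn(children, context) -> obj.
# Only flatten_fn is used here; an exporter also needs unflatten_fn to rebuild outputs.
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}

LEAF_TYPES = (int, float)  # stand-ins for tensors


def register(cls: type, flatten_fn: Callable, unflatten_fn: Callable) -> None:
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


register(list, lambda x: (list(x), None), lambda ch, ctx: list(ch))
register(tuple, lambda x: (list(x), None), lambda ch, ctx: tuple(ch))
register(dict, lambda x: (list(x.values()), list(x.keys())),
         lambda ch, ctx: dict(zip(ctx, ch)))


def flatten(obj: Any) -> List[Any]:
    """Returns the leaves of obj; raises TypeError on an unregistered container."""
    if isinstance(obj, LEAF_TYPES):
        return [obj]
    entry = _REGISTRY.get(type(obj))
    if entry is None:
        raise TypeError(f"cannot serialize type {type(obj).__name__!r}")
    children, _context = entry[0](obj)
    leaves: List[Any] = []
    for child in children:
        leaves.extend(flatten(child))
    return leaves


class MyCache:
    """Hypothetical cache-like container, unknown to the registry at first."""

    def __init__(self, keys: list, values: list):
        self.keys = keys
        self.values = values


inputs = {"x": 1.0, "cache": MyCache([2.0], [3.0])}
try:
    flatten(inputs)
except TypeError as e:
    print("before registration:", e)

# Registering flatten/unflatten functions makes the type serializable.
register(MyCache, lambda c: ([c.keys, c.values], None),
         lambda ch, ctx: MyCache(ch[0], ch[1]))
print("after registration:", flatten(inputs))
```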

Examples:

with bypass_export_some_errors(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    onx = to_onnx(..., inputs, ...)

with bypass_export_some_errors(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    onx = torch.onnx.export(..., inputs, ...)

It can also be used to work around errors raised by torch.export.export():

with bypass_export_some_errors(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When running the model through the exported program, only the serialization functions need to be restored:

with register_additional_serialization_functions() as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When exporting a model with a cache, the following error message may appear: AssertionError: Mutating module attribute _seen_tokens during export. It can be avoided by setting strict=False when calling torch.export.export().

onnx_diagnostic.torch_export_patches.register_additional_serialization_functions(patch_transformers: bool = False, verbose: int = 0) -> Callable

Applies the modifications necessary to run the fx Graph, that is, the registration of the additional serialization functions.
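The sketch below shows one way a context manager could register missing serialization functions for the duration of a block and remove them afterwards, leaving untouched any type that was already registered. It is an assumption-based illustration, not the code behind register_additional_serialization_functions; `SERIALIZERS`, `registered_serialization`, `MyCache`, `flatten_cache` and `unflatten_cache` are all hypothetical.

```python
import contextlib
from typing import Callable, Dict, Tuple

# Hypothetical global registry: type -> (flatten_fn, unflatten_fn).
SERIALIZERS: Dict[type, Tuple[Callable, Callable]] = {}


class MyCache:
    """Hypothetical container the exporter cannot serialize by default."""

    def __init__(self, keys: list, values: list):
        self.keys, self.values = keys, values


def flatten_cache(cache: MyCache):
    # Children first, then a context object (unused here).
    return [cache.keys, cache.values], None


def unflatten_cache(children, context) -> MyCache:
    return MyCache(children[0], children[1])


@contextlib.contextmanager
def registered_serialization(extra: Dict[type, Tuple[Callable, Callable]], verbose: int = 0):
    """Sketch: register missing serialization functions, remove them on exit."""
    added = []
    for cls, fns in extra.items():
        if cls in SERIALIZERS:
            continue  # registered elsewhere: leave it untouched
        SERIALIZERS[cls] = fns
        added.append(cls)
        if verbose:
            print(f"[registered_serialization] register {cls.__name__}")
    try:
        yield
    finally:
        # Remove only what this block added.
        for cls in added:
            del SERIALIZERS[cls]
            if verbose:
                print(f"[registered_serialization] unregister {cls.__name__}")


with registered_serialization({MyCache: (flatten_cache, unflatten_cache)}, verbose=1):
    children, ctx = SERIALIZERS[MyCache][0](MyCache([1], [2]))
    rebuilt = SERIALIZERS[MyCache][1](children, ctx)
    print(rebuilt.keys, rebuilt.values)
print(MyCache in SERIALIZERS)
```

Skipping types that are already registered avoids overwriting, and later removing, registrations owned by another library.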