onnx_diagnostic.torch_export_patches

onnx_diagnostic.torch_export_patches.bypass_export_some_errors(patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, replace_dynamic_cache: bool = False, catch_constraints: bool = True, verbose: int = 0, patch: bool = True) → Callable

Tries to bypass some situations that torch.export.export() does not support.

Parameters:
  • patch_sympy – fixes a missing method name for sympy.S.IntegerConstant

  • patch_torch – patches torch with supported implementations

  • patch_transformers – patches transformers with supported implementations

  • replace_dynamic_cache – replaces DynamicCache with a patched class that avoids issues with dynamic shape inference; it should be True for LLMs using that class, and only during export

  • catch_constraints – catches constraints related to dynamic shapes; as a result, some dynamic dimensions may turn into static ones. The environment variable SKIP_SOLVE_CONSTRAINTS=0 can be set to stop at that stage.

  • patch – if False, disables all patches except the registration of serialization functions
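The constraint-related behavior is controlled through an environment variable rather than a function parameter. A minimal sketch of setting it before export (the variable name comes from the description above; its effect on the export is as described there):

```python
import os

# Per the parameter description above: SKIP_SOLVE_CONSTRAINTS=0 makes the
# export stop at the constraint-solving stage instead of silently turning
# dynamic dimensions into static ones.
os.environ["SKIP_SOLVE_CONSTRAINTS"] = "0"
```

This must be set before the export runs, since the variable is read at that stage.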

The list of available patches:

  • torch.jit.isinstance

  • torch._dynamo.mark_static_address

  • torch._subclasses.fake_impls.infer_size

  • fixes the missing method name for sympy.S.IntegerConstant

  • AttentionMaskConverter._make_causal_mask

  • Serialization of MambaCache (in transformers)

  • Serialization of DynamicCache (in transformers)

  • reduces errors due to shape inference

  • replaces transformers.cache_utils.DynamicCache with patched_DynamicCache
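Each of these patches follows the same save/patch/restore pattern: the context manager swaps in a replacement implementation on entry and restores the original on exit, even if the export fails. A minimal sketch of that mechanism (patch_attribute and the Toy class are illustrative, not the library's actual internals):

```python
import contextlib


@contextlib.contextmanager
def patch_attribute(owner, name, replacement):
    """Temporarily replace ``owner.name`` with ``replacement``.

    Sketch of the save/patch/restore pattern used by the patches
    listed above; not the library's actual implementation.
    """
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield original
    finally:
        # Always restore the original, even if the export raised.
        setattr(owner, name, original)


# Usage on a toy class standing in for a patched module:
class Toy:
    @staticmethod
    def f(x):
        return x + 1


with patch_attribute(Toy, "f", staticmethod(lambda x: x * 2)):
    assert Toy.f(3) == 6  # patched behavior inside the context
assert Toy.f(3) == 4      # original behavior restored on exit
```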

Serialization issues happen when a module takes an input or returns an output whose type torch.export.export() cannot serialize.
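Making such a type serializable means giving torch.export a way to flatten it into plain tensors and rebuild it afterwards; in PyTorch this goes through the private helper torch.utils._pytree.register_pytree_node, which is what the MambaCache and DynamicCache entries above rely on. The round-trip contract those functions must satisfy can be sketched in pure Python (CacheLike and both functions are hypothetical stand-ins, not the library's actual code):

```python
from typing import Any, List, Tuple


class CacheLike:
    """Hypothetical stand-in for a cache class such as DynamicCache."""

    def __init__(self, keys, values):
        self.keys = keys
        self.values = values


def flatten_cache(cache: CacheLike) -> Tuple[List[Any], Any]:
    # Returns the children (the tensors) and a static context,
    # the shape a pytree flatten function is expected to have.
    return [cache.keys, cache.values], None


def unflatten_cache(children: List[Any], context: Any) -> CacheLike:
    keys, values = children
    return CacheLike(keys, values)


# The registered pair must round-trip: unflatten(flatten(x)) == x.
cache = CacheLike([1, 2], [3, 4])
children, ctx = flatten_cache(cache)
restored = unflatten_cache(children, ctx)
assert restored.keys == cache.keys and restored.values == cache.values
```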

Examples:

with bypass_export_some_errors(
    patch_transformers=True,
    replace_dynamic_cache=True,
) as modificator:
    inputs = modificator(inputs)
    onx = to_onnx(..., inputs, ...)

The same pattern works with torch.onnx.export():

with bypass_export_some_errors(
    patch_transformers=True,
    replace_dynamic_cache=True,
) as modificator:
    inputs = modificator(inputs)
    onx = torch.onnx.export(..., inputs, ...)

It can also be used to fix the torch export:

with bypass_export_some_errors(
    patch_transformers=True,
    replace_dynamic_cache=True,
) as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When running the model through the exported program, only the serialization functions need to be restored:

with register_additional_serialization_functions() as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When exporting a model with a cache, the following error message may appear: AssertionError: Mutating module attribute _seen_tokens during export. It can be avoided by setting strict=False when calling torch.export.export().

onnx_diagnostic.torch_export_patches.register_additional_serialization_functions(verbose: int = 0, replace_dynamic_cache: bool = False) → Callable

Registers the serialization functions necessary to run the fx graph, without applying any other patch.