experimental_experiment.torch_interpreter.onnx_export_errors

experimental_experiment.torch_interpreter.onnx_export_errors.bypass_export_some_errors(patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, replace_dynamic_cache: bool = False, verbose: int = 0) → Callable

Tries to bypass some situations torch.export.export() does not support.

Parameters:
  • patch_sympy – fixes a missing method name for IntegerConstant

  • patch_torch – patches torch with supported implementations

  • patch_transformers – patches transformers with supported implementations

  • replace_dynamic_cache – replaces DynamicCache with a patched class that avoids issues with dynamic shape inference; it should be True for LLMs using that class, and only during the export

The list of available patches:

  • torch.jit.isinstance

  • torch._dynamo.mark_static_address

  • torch._subclasses.fake_impls.infer_size

  • fix missing method name for sympy.S.IntegerConstant

  • AttentionMaskConverter._make_causal_mask

  • Serialization of MambaCache (in transformers)

  • Serialization of DynamicCache (in transformers)

  • reduce errors due to shape inference

  • replaces transformers.cache_utils.DynamicCache with patched_DynamicCache
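The patches above are applied temporarily and reverted when the context manager exits. A minimal sketch of that patch/restore pattern (the `temporary_patch` helper and the `demo` namespace below are hypothetical, not part of the library):

```python
import contextlib
import types

# Hypothetical namespace standing in for a torch/transformers internal module.
demo = types.SimpleNamespace(infer_size=lambda a, b: a + b)

@contextlib.contextmanager
def temporary_patch(owner, name, replacement):
    """Replace ``owner.name`` with ``replacement``, restoring it on exit."""
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)

with temporary_patch(demo, "infer_size", lambda a, b: max(a, b)):
    patched = demo.infer_size(2, 5)   # uses the replacement: max(2, 5)
restored = demo.infer_size(2, 5)      # original is back: 2 + 5
```

bypass_export_some_errors follows the same pattern at a larger scale, swapping several torch, sympy, and transformers attributes for the duration of the export.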

Serialization issues happen when a module takes an input or returns an output whose type torch.export.export() cannot serialize.
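Concretely, torch.export.export() flattens inputs through pytree, so a custom cache class needs a flatten/unflatten pair. A minimal sketch with a hypothetical MyCache class (the real registration for DynamicCache or MambaCache goes through torch.utils._pytree.register_pytree_node):

```python
class MyCache:
    """Hypothetical cache class, standing in for DynamicCache/MambaCache."""
    def __init__(self, key_cache=None, value_cache=None):
        self.key_cache = key_cache or []
        self.value_cache = value_cache or []

def flatten_my_cache(cache):
    # Returns the dynamic leaves plus the static context to rebuild the object.
    return [cache.key_cache, cache.value_cache], ["key_cache", "value_cache"]

def unflatten_my_cache(values, context):
    cache = MyCache()
    for name, value in zip(context, values):
        setattr(cache, name, value)
    return cache

# With torch available, the pair would be registered like this:
# torch.utils._pytree.register_pytree_node(
#     MyCache, flatten_my_cache, unflatten_my_cache)

cache = MyCache([1, 2], [3, 4])
values, context = flatten_my_cache(cache)
rebuilt = unflatten_my_cache(values, context)
```

Once registered, the exporter can flatten the cache into tensors it understands and rebuild it on the way out.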

Examples:

with bypass_export_some_errors(
    patch_transformers=True,
    replace_dynamic_cache=True,
) as modificator:
    inputs = modificator(inputs)
    onx = to_onnx(..., inputs, ...)
with bypass_export_some_errors(
    patch_transformers=True,
    replace_dynamic_cache=True,
) as modificator:
    inputs = modificator(inputs)
    onx = torch.onnx.export(..., inputs, ...)

It can also be used to fix the torch export:

with bypass_export_some_errors(
    patch_transformers=True,
    replace_dynamic_cache=True,
) as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When running the model through the exported program, only the serialization functions need to be restored:

with register_additional_serialization_functions() as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When exporting a model with a cache, the following error message may appear: AssertionError: Mutating module attribute _seen_tokens during export. It can be avoided by setting strict=False when calling torch.export.export().

experimental_experiment.torch_interpreter.onnx_export_errors.register_additional_serialization_functions(verbose: int = 0) → Callable

Registers the serialization functions necessary to run the fx Graph.

experimental_experiment.torch_interpreter.onnx_export_errors.replacement_before_exporting(args: Any) → Any

Does replacements on the given inputs, such as replacing transformers.cache_utils.DynamicCache by experimental_experiment.torch_interpreter.patches.patched_transformers.patched_DynamicCache.
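A minimal sketch of that kind of replacement over nested inputs (OldCache, NewCache, and replace_in_args are hypothetical names; the real function targets DynamicCache specifically):

```python
class OldCache:
    """Stands in for transformers.cache_utils.DynamicCache."""
    def __init__(self, data):
        self.data = data

class NewCache:
    """Stands in for the patched replacement class."""
    def __init__(self, data):
        self.data = data

def replace_in_args(args):
    """Recursively replace OldCache instances inside nested containers."""
    if isinstance(args, OldCache):
        return NewCache(args.data)
    if isinstance(args, (list, tuple)):
        return type(args)(replace_in_args(a) for a in args)
    if isinstance(args, dict):
        return {k: replace_in_args(v) for k, v in args.items()}
    return args

inputs = (1, [OldCache([0, 1])], {"cache": OldCache([2])})
replaced = replace_in_args(inputs)
```

This is why the context manager yields a modificator callable in the examples above: inputs must be rewritten before being handed to the exporter.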