onnx_diagnostic.torch_export_patches

onnx_diagnostic.torch_export_patches.register_flattening_functions(verbose: int = 0)

Registers functions to serialize and deserialize caches or other classes implemented in transformers and used as inputs. This is needed whenever a model must be exported through torch.export.export().
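
As a rough illustration of what serializing and deserializing a cache involves, the sketch below writes a flatten/unflatten pair for a hypothetical cache class. ToyCache, flatten_toy_cache and unflatten_toy_cache are made-up names that belong neither to transformers nor to onnx_diagnostic, and plain lists stand in for tensors so that the snippet runs without torch. It only shows the general shape of such a pair (leaves plus a context), not the functions this package registers.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class ToyCache:
    """Hypothetical stand-in for a transformers cache class."""

    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


def flatten_toy_cache(cache: ToyCache) -> Tuple[List[Any], List[str]]:
    """Splits the cache into leaves (the data) and a context (the field names)."""
    context = ["key_cache", "value_cache"]
    leaves = [cache.key_cache, cache.value_cache]
    return leaves, context


def unflatten_toy_cache(leaves: List[Any], context: List[str]) -> ToyCache:
    """Rebuilds the cache from the leaves and the context."""
    return ToyCache(**dict(zip(context, leaves)))


cache = ToyCache(key_cache=[[1.0, 2.0]], value_cache=[[3.0, 4.0]])
leaves, context = flatten_toy_cache(cache)
restored = unflatten_toy_cache(leaves, context)
```

The round trip returns an object equal to the original one, which is the property an exporter relies on when it replaces a structured input by a flat list of values.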

onnx_diagnostic.torch_export_patches.torch_export_rewrite(rewrite: torch.nn.Module | List[Tuple[type, str] | Callable] | None = None, dump_rewriting: str | None = None, verbose: int = 0)

Automatically rewrites the methods given in rewrite so that their control flow (tests and loops) can be exported.

Parameters:
  • rewrite – methods or functions to rewrite. If not empty, the function may try to discover them. A method is defined by its class (a type) and its name if the class is local, and by itself otherwise. It can also be a model; in that case, the function calls code_needing_rewriting to retrieve the necessary rewriting

  • dump_rewriting – dumps the rewriting into that folder; if the folder does not exist, it is created

  • verbose – verbosity, up to 10: verbose=1 shows the rewritten function, verbose=2 shows the rewritten code as well, and 10 shows the rewritten code
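
To make the idea of locating control flow more concrete, here is a minimal sketch that uses the standard ast module to list the tests and loops found in a piece of source code. It is not how torch_export_rewrite is implemented. find_control_flow and SOURCE are hypothetical names introduced for this illustration, and the sketch only detects the statements; it does not rewrite them.

```python
import ast
import textwrap

# Source code of a function whose execution path depends on its inputs.
SOURCE = textwrap.dedent(
    """
    def forward(x, y):
        if x.sum() > 0:
            return x + y
        else:
            return abs(x) + y + 1
    """
)


def find_control_flow(source: str):
    """Returns (function name, node type) for every test or loop in a function."""
    found = []
    tree = ast.parse(source)
    for func in ast.walk(tree):
        if not isinstance(func, ast.FunctionDef):
            continue
        for node in ast.walk(func):
            if isinstance(node, (ast.If, ast.For, ast.While)):
                found.append((func.name, type(node).__name__))
    return found


control_flow = find_control_flow(SOURCE)
print(control_flow)
```

A function reported by such a scan is a candidate for the rewrite argument, provided the test or loop really depends on tensor values rather than on constants.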

Example:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_rewrite

class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            return x + y
        else:
            return torch.abs(x) + y + 1

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
with torch_export_rewrite(rewrite=[(Model, "forward")]):
    ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

If the method to rewrite is not local, then the following can be used:

with torch_export_rewrite(rewrite=[Model.forward]):
    ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

Functions (if not local) can also be rewritten:

def outside(x, y):
    if x.sum() > 0:
        return x + y
    else:
        return torch.abs(x) + y + 1

class Model(torch.nn.Module):
    def forward(self, x, y):
        return outside(x, y)

model = Model()
x, y = torch.rand((4, 5)), torch.rand((4, 5))
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
with torch_export_rewrite(rewrite=[outside]):
    ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

onnx_diagnostic.torch_export_patches.torch_export_patches(patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, catch_constraints: bool = True, stop_if_static: int = 0, verbose: int = 0, patch: bool = True, custom_patches: List[type[torch.nn.Module]] | None = None, rewrite: List[Callable] | None = None, dump_rewriting: str | None = None) -> Callable

Tries to bypass some situations torch.export.export() does not support. See also Patches Explained and Coverage of the Patches.

Parameters:
  • patch_sympy – fix missing method name for IntegerConstant

  • patch_torch – patches torch with supported implementation

  • patch_transformers – patches transformers with supported implementation

  • catch_constraints – catches constraints related to dynamic shapes; as a result, some dynamic dimensions may turn into static ones. The environment variable SKIP_SOLVE_CONSTRAINTS=0 can be set to stop at that stage.

  • stop_if_static – see example Find and fix an export issue due to dynamic shapes; stops the export as soon as an issue is detected with dynamic shapes and shows a stack trace indicating the exact location of the issue. If stop_if_static > 1, more methods are replaced to catch more issues

  • patch – if False, disables all patches except the registration of the serialization functions

  • custom_patches – custom patches to apply; every patched class must define the static attributes _PATCHES_ and _PATCHED_CLASS_

  • rewrite – list of methods to automatically rewrite before exporting. Methods with control flow need to be rewritten before being exported if the execution path depends on the inputs. This is done by function transform_method; its documentation provides the possible values

  • dump_rewriting – dumps rewriting information in file beginning with that prefix

  • verbose – shows which patches are applied
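
The mechanism behind custom_patches can be pictured with the following sketch of a context manager that temporarily replaces methods and restores them on exit. It is an illustration under assumptions, not the package's code: apply_patches, Original and PatchedOriginal are hypothetical names, and the sketch assumes that _PATCHES_ lists the names of the methods to replace and that _PATCHED_CLASS_ is the class receiving them.

```python
import contextlib


class Original:
    """Hypothetical class whose method cannot be exported as is."""

    def forward(self, x):
        return x + 1


class PatchedOriginal:
    """Hypothetical patch following the documented attribute convention."""

    _PATCHES_ = ["forward"]  # assumed: names of the methods to replace
    _PATCHED_CLASS_ = Original  # assumed: class receiving the replacements

    def forward(self, x):
        return x + 2


@contextlib.contextmanager
def apply_patches(patch_classes):
    """Installs the patched methods, then restores the originals on exit."""
    saved = []
    try:
        for patch in patch_classes:
            target = patch._PATCHED_CLASS_
            for name in patch._PATCHES_:
                # Keep the original function so it can be put back later.
                saved.append((target, name, target.__dict__[name]))
                setattr(target, name, patch.__dict__[name])
        yield
    finally:
        for target, name, original in reversed(saved):
            setattr(target, name, original)


model = Original()
before = model.forward(1)
with apply_patches([PatchedOriginal]):
    during = model.forward(1)
after = model.forward(1)
```

Restoring in a finally clause keeps the original methods intact even when the code run inside the with block raises an exception.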

The list of available patches.

Serialization issues happen when a module takes an input or returns an output whose type torch.export.export() cannot serialize.

Examples:

with torch_export_patches(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    onx = to_onnx(..., inputs, ...)

with torch_export_patches(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    onx = torch.onnx.export(..., inputs, ...)

It can be used as well to fix the torch export:

with torch_export_patches(patch_transformers=True) as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When running the model through the exported program, only the serialization functions need to be restored:

with register_additional_serialization_functions() as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(..., inputs, ...)

When exporting a model with a cache, the following error message may appear: AssertionError: Mutating module attribute _seen_tokens during export. It can be avoided by setting strict=False when calling torch.export.export().

onnx_diagnostic.torch_export_patches.register_additional_serialization_functions(patch_transformers: bool = False, verbose: int = 0) -> Callable

Applies the modifications necessary to run the fx Graph.