Source code for onnx_diagnostic.torch_export_patches.patch_expressions

from typing import Callable, Set
import torch
from ..helpers.torch_helper import is_torchdynamo_exporting


def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Returns a dimension computed from the content of a tensor.

    Used by a custom op when a new dimension must be introduced
    to bypass some verification. The following function creates
    a dummy output with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )

    :param i: requested size; if it cannot be converted with ``int``,
        10 is used instead
    :return: the number of non-zero entries of a helper tensor built so
        that this count equals the requested size
    """
    try:
        ti = int(i)
    except Exception:
        # Best effort: any conversion failure falls back to a fixed size.
        # ``Exception`` (not a bare ``except``) keeps KeyboardInterrupt and
        # SystemExit propagating.
        ti = 10
    # 2 * ti ones, with the first ti entries zeroed: exactly ti values remain
    # non-zero, so the size comes from tensor content.
    # NOTE(review): presumably this is what makes the dimension opaque to the
    # exporter - confirm.
    t = torch.ones((ti * 2,))
    t[:ti] = 0
    res = torch.nonzero(t).shape[0]
    return res
def _patched_float_arange(
    start: torch.Tensor, end: torch.Tensor, step: torch.Tensor
) -> torch.Tensor:
    """
    Float arange.

    Reads the scalar value of each of the three tensors and calls
    :func:`torch.arange` with the dtype and device of *start*.
    """
    first = float(start.item())
    last = float(end.item())
    delta = float(step.item())
    return torch.arange(first, last, delta, dtype=start.dtype, device=start.device)


def _patched_float_arange_shape(start, end, step):
    """
    Shape function paired with :func:`_patched_float_arange`.

    Returns an empty one-dimensional tensor with the dtype and device of
    *start*; its length comes from ``make_undefined_dimension(10)``.
    """
    # The length is not derived from the inputs. An earlier attempt,
    #     n = math.ceil(((end - start) / step).item())
    # fails because:
    #     Did you accidentally call new_dynamic_size() or item()
    #     more times than you needed to in your fake implementation?
    # so a fixed placeholder size is used instead.
    length = make_undefined_dimension(10)
    return torch.empty((length,), dtype=start.dtype, device=start.device)


def _iterate_patched_expressions():
    """
    Iterates over the patched expressions defined in this module.

    Every module-level name starting with ``_patched_`` and not ending with
    ``_shape`` yields a tuple ``(short name, function, shape function)``,
    the shape function being the global named ``<name>_shape``.
    """
    prefix = "_patched_"
    # Work on a snapshot of the module globals.
    snapshot = dict(globals())
    for key, fct in snapshot.items():
        if not key.startswith(prefix) or key.endswith("_shape"):
            continue
        yield key[len(prefix) :], fct, snapshot[f"{key}_shape"]


# Names of the expressions already registered as custom ops.
_registered: Set[str] = set()


def _register_patched_expression(
    fct: Callable, fct_shape: Callable, namespace: str, fname: str
):
    """
    Registers *fct* as the custom op ``namespace::fname``.

    :param fct: implementation, registered as the ``"cpu"`` kernel;
        its schema is inferred from its signature
    :param fct_shape: function stored as the abstract (shape) function
    :param namespace: namespace of the custom op
    :param fname: name of the custom op inside the namespace
    """
    schema = torch.library.infer_schema(fct, mutates_args=())
    op_def = torch.library.CustomOpDef(namespace, fname, schema, fct)
    attach_cpu_kernel = op_def.register_kernel("cpu")
    attach_cpu_kernel(fct)
    # NOTE(review): this writes a private attribute directly; the public
    # ``register_fake`` may be preferable - confirm.
    op_def._abstract_fn = fct_shape
def register_patched_expressions(namespace: str = "patched"):
    """
    Registers as custom ops known expressions failing due to dynamic shapes.

    Registration is tracked per ``namespace::name`` pair, so calling this
    function again with the same namespace does nothing, while calling it
    with another namespace registers the expressions there as well.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.patch_expressions import (
            _iterate_patched_expressions,
        )

        pprint.pprint([name for name, _f, _fsh in _iterate_patched_expressions()])

    :param namespace: namespace in which the custom ops are registered
    """
    for name, f, fsh in _iterate_patched_expressions():
        # Keying on the bare name would skip every later call made with
        # another namespace.
        key = f"{namespace}::{name}"
        if key in _registered:
            continue
        _register_patched_expression(f, fsh, namespace, name)
        _registered.add(key)
def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
    """
    Chooses between a function and its patched version.

    :param fct: function returned when the model is being executed
    :param patched_fct: function returned when the model is being exported
    :return: **patched_fct** if :func:`is_torchdynamo_exporting` is true,
        **fct** otherwise
    """
    if is_torchdynamo_exporting():
        return patched_fct
    return fct
def patched_float_arange(start, end, step):
    """
    Patched arange when start, end, step are floats.

    Outside of an export this is a plain :func:`torch.arange` call; during
    an export the call goes to the custom op ``torch.ops.patched.float_arange``.
    """
    if not is_torchdynamo_exporting():
        return torch.arange(start, end, step)
    return torch.ops.patched.float_arange(start, end, step)