from typing import Callable, Set

import torch

# is_torchdynamo_exporting is assumed to come from the package helpers of
# onnx_diagnostic (the exact import path is not shown in this listing); it
# returns True while torch.export is tracing the model.
from ..helpers.torch_helper import is_torchdynamo_exporting


def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used in a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    try:
        ti = int(i)
    except:  # noqa: E722
        ti = 10
    # ``torch.nonzero`` produces a data-dependent size: in eager mode the
    # result simply has ``ti`` elements, but under fake tensors that size
    # becomes an unbacked symbolic dimension.
    t = torch.ones((ti * 2,))
    t[:ti] = 0
    res = torch.nonzero(t).shape[0]
    return res
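

# A minimal sketch (hypothetical op) of the intended usage: the real kernel
# produces an output whose size depends on the data, and the fake (shape)
# implementation replaces that size with an undefined dimension so that
# export does not try to verify it.
def _boolean_mask(x: torch.Tensor) -> torch.Tensor:
    return x[x > 0]  # output size depends on the data


def _boolean_mask_shape(x):
    # the true size is unknown at trace time; the integer passed here only
    # seeds the dummy dimension, its exact value does not matter
    return torch.empty(
        (make_undefined_dimension(x.shape[0]),),
        dtype=x.dtype,
        device=x.device,
    )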
def _patched_float_arange(
    start: torch.Tensor, end: torch.Tensor, step: torch.Tensor
) -> torch.Tensor:
    """Float arange."""
    return torch.arange(
        float(start.item()),
        float(end.item()),
        float(step.item()),
        dtype=start.dtype,
        device=start.device,
    )


def _patched_float_arange_shape(start, end, step):
    # The commented computation fails during export with:
    #   Did you accidentally call new_dynamic_size() or item()
    #   more times than you needed to in your fake implementation?
    # try:
    #     n = math.ceil(((end - start) / step).item())
    # except:  # noqa: E722
    #     n = 10
    n = 10
    return torch.empty(
        (make_undefined_dimension(n),), dtype=start.dtype, device=start.device
    )


def _iterate_patched_expressions():
    # discovers every pair ``_patched_<name>`` / ``_patched_<name>_shape``
    # defined in this module
    glo = globals().copy()
    for k, _v in glo.items():
        if k.startswith("_patched_") and not k.endswith("_shape"):
            name = k
            yield k[len("_patched_") :], glo[name], glo[f"{name}_shape"]


_registered: Set[str] = set()


def _register_patched_expression(
    fct: Callable, fct_shape: Callable, namespace: str, fname: str
):
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape
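

# Sketch of how a new patched expression would be added (both functions are
# hypothetical): following the ``_patched_<name>`` / ``_patched_<name>_shape``
# naming convention is enough, the discovery loop above picks up the pair and
# ``register_patched_expressions`` (below) registers it automatically.
def _patched_float_full(value: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    """Returns a 1D tensor filled with ``value``, of data-dependent length."""
    return torch.full(
        (int(n.item()),),
        float(value.item()),
        dtype=value.dtype,
        device=value.device,
    )


def _patched_float_full_shape(value, n):
    # the length only exists at runtime, seed it with an undefined dimension
    return torch.empty(
        (make_undefined_dimension(10),), dtype=value.dtype, device=value.device
    )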
def register_patched_expressions(namespace: str = "patched"):
    """
    Registers known expressions failing due to dynamic shapes as custom ops.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.patch_expressions import (
            _iterate_patched_expressions,
        )

        pprint.pprint([name for name, _f, _fsh in _iterate_patched_expressions()])
    """
    for name, f, fsh in _iterate_patched_expressions():
        if name not in _registered:
            _register_patched_expression(f, fsh, namespace, name)
            _registered.add(name)
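

# Usage sketch: registration is idempotent thanks to the ``_registered`` set,
# so it can be called before every export; afterwards each expression is
# reachable under ``torch.ops.<namespace>.<name>``.
def _example_registration():
    register_patched_expressions(namespace="patched")
    # the registered kernel runs like any other torch operator
    return torch.ops.patched.float_arange(
        torch.tensor(0.0), torch.tensor(1.0), torch.tensor(0.25)
    )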
def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
    """
    Returns **fct** if the model is being executed
    or **patched_fct** if it is being exported.
    """
    return patched_fct if is_torchdynamo_exporting() else fct
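

# Sketch (hypothetical model): the selection happens at call time, so the
# eager implementation runs during normal execution and the exportable
# variant is traced by torch.export. This is the generic form of what
# ``patched_float_arange`` below spells out with an explicit test.
class _ExampleModel(torch.nn.Module):
    def forward(self, start, end, step):
        fct = patched_selector(
            lambda s, e, st: torch.arange(s.item(), e.item(), st.item()),
            lambda s, e, st: torch.ops.patched.float_arange(s, e, st),
        )
        return fct(start, end, step)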
def patched_float_arange(start, end, step):
    """Patched arange when start, end, step are floats."""
    if is_torchdynamo_exporting():
        return torch.ops.patched.float_arange(start, end, step)
    else:
        return torch.arange(start, end, step)
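

# End-to-end sketch: exporting a model that builds a float range from tensor
# inputs. Without the patched operator, the data-dependent length of
# ``torch.arange`` makes the export fail; whether this sketch exports cleanly
# may still depend on the torch version.
def _example_export():
    register_patched_expressions()

    class Model(torch.nn.Module):
        def forward(self, start, end, step):
            return patched_float_arange(start, end, step)

    return torch.export.export(
        Model(),
        (torch.tensor(0.0), torch.tensor(10.0), torch.tensor(0.5)),
    )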