# Source code for experimental_experiment.helpers

import inspect
from typing import Any, Callable, Dict, Optional, Tuple, Union
import numpy as np
from onnx import FunctionProto, GraphProto, ModelProto, load as onnx_load
from onnx.helper import np_dtype_to_tensor_dtype


def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """
    Converts an onnx element type (``TensorProto.data_type``) into the
    matching numpy dtype, useful when building a tensor.

    Element types greater than or equal to 16 are rejected before onnx
    is consulted.

    :param tensor_dtype: onnx element type
    :return: numpy dtype
    :raises ValueError: if *tensor_dtype* is 16 or above
    """
    unsupported = tensor_dtype >= 16
    if unsupported:
        raise ValueError(
            f"Unsupported value for tensor_dtype, "
            f"numpy does not support onnx type {tensor_dtype}."
        )
    # imported lazily, only once a conversion is actually requested
    from onnx.helper import tensor_dtype_to_np_dtype as _onnx_to_np

    return _onnx_to_np(tensor_dtype)
def string_type(obj: Any, with_shape: bool = False, with_min_max: bool = False) -> str:
    """
    Displays the types of an object as a string.

    Containers are displayed recursively: tuples as ``(...)``, lists as
    ``#<len>[...]``, sets as ``{...}`` and dictionaries as ``dict(k:v,...)``.
    Numpy arrays are displayed as ``A<onnx type>r<rank>`` (``A<onnx type>s<shape>``
    when *with_shape* is True), torch tensors the same way with a ``T`` prefix.
    Plain python objects (int, float, str, slice, containers) do not need
    torch to be installed.

    :param obj: any
    :param with_shape: displays shapes as well
    :param with_min_max: displays information about the values
        (minimum and maximum for arrays and tensors, ``[empty]`` when they
        contain no element, the value itself for int and float)
    :return: str
    :raises AssertionError: if the type of *obj* is not supported

    .. runpython::
        :showcode:

        from experimental_experiment.helpers import string_type

        print(string_type((1, ["r", 6.6])))
    """

    def sub(o: Any) -> str:
        # recursive call propagating the display options
        return string_type(o, with_shape=with_shape, with_min_max=with_min_max)

    if obj is None:
        return "None"
    if isinstance(obj, tuple):
        if len(obj) == 1:
            return f"({sub(obj[0])},)"
        js = ",".join(sub(o) for o in obj)
        return f"({js})"
    if isinstance(obj, list):
        js = ",".join(sub(o) for o in obj)
        return f"#{len(obj)}[{js}]"
    if isinstance(obj, set):
        js = ",".join(sub(o) for o in obj)
        return f"{{{js}}}"
    if isinstance(obj, dict):
        s = ",".join(f"{k}:{sub(v)}" for k, v in obj.items())
        return f"dict({s})"
    if isinstance(obj, np.ndarray):
        if with_min_max:
            s = string_type(obj, with_shape=with_shape)
            if obj.size == 0:
                # min()/max() raise on zero-size arrays
                return f"{s}[empty]"
            return f"{s}[{obj.min()}:{obj.max()}]"
        i = np_dtype_to_tensor_dtype(obj.dtype)
        if not with_shape:
            return f"A{i}r{len(obj.shape)}"
        return f"A{i}s{'x'.join(map(str, obj.shape))}"

    # Plain python scalars are handled before torch is imported so that
    # displaying them does not require torch to be installed.
    if isinstance(obj, int):
        if with_min_max:
            return f"int[{obj}]"
        return "int"
    if isinstance(obj, float):
        if with_min_max:
            return f"float[{obj}]"
        return "float"
    if isinstance(obj, str):
        return "str"
    if isinstance(obj, slice):
        return "slice"

    try:
        import torch
    except ImportError:
        # torch is optional, only torch-specific types need it
        torch = None

    if torch is not None:
        if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
            return "DerivedDim"
        if isinstance(obj, torch.export.dynamic_shapes._Dim):
            return "Dim"
        if isinstance(obj, torch.SymInt):
            return "SymInt"
        if isinstance(obj, torch.SymFloat):
            return "SymFloat"
        if isinstance(obj, torch.Tensor):
            if with_min_max:
                s = string_type(obj, with_shape=with_shape)
                if obj.numel() == 0:
                    # min()/max() raise on tensors without any element
                    return f"{s}[empty]"
                if obj.dtype in {torch.complex64, torch.complex128}:
                    return f"{s}[{obj.abs().min()}:{obj.abs().max()}]"
                return f"{s}[{obj.min()}:{obj.max()}]"
            from .xbuilder._dtype_helper import torch_dtype_to_onnx_dtype

            i = torch_dtype_to_onnx_dtype(obj.dtype)
            if not with_shape:
                return f"T{i}r{len(obj.shape)}"
            return f"T{i}s{'x'.join(map(str, obj.shape))}"

    # others classes, recognized by their name to avoid importing their package
    if type(obj).__name__ == "MambaCache":
        c = sub(obj.conv_states)
        d = sub(obj.ssm_states)
        return f"MambaCache(conv_states={c}, ssm_states={d})"
    if type(obj).__name__ == "Node" and hasattr(obj, "meta"):
        # torch.fx.node.Node
        return f"%{obj.target}"
    if type(obj).__name__ == "ValueInfoProto":
        return f"OT{obj.type.tensor_type.elem_type}"
    if obj.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
        kc = sub(obj.key_cache)
        vc = sub(obj.value_cache)
        return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
    if obj.__class__.__name__ == "BatchFeature":
        s = sub(obj.data)
        return f"BatchFeature(data={s})"
    if obj.__class__.__name__ == "BatchEncoding":
        s = sub(obj.data)
        return f"BatchEncoding(data={s})"
    if obj.__class__.__name__ == "VirtualTensor":
        return (
            f"{obj.__class__.__name__}(name={obj.name!r}, "
            f"dtype={obj.dtype}, shape={obj.shape})"
        )
    if obj.__class__.__name__ == "_DimHint":
        return str(obj)
    if torch is not None and isinstance(obj, torch.nn.Module):
        return f"{obj.__class__.__name__}(...)"
    raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
def string_signature(sig: Any) -> str:
    """
    Displays the signature of a function.

    Every parameter is rendered on its own line as
    ``name: annotation = default``, padded to 30 characters, followed by ``|``
    and the name of the parameter kind (``POSITIONAL_OR_KEYWORD``,
    ``VAR_POSITIONAL``, ``KEYWORD_ONLY``, ...). ``*args`` is prefixed
    with ``*`` and ``**kwargs`` with ``**``.

    :param sig: an :class:`inspect.Signature`
    :return: a multi-line string
    """
    text = [" __call__ ("]
    for p, pp in sig.parameters.items():
        # keep the enum itself: comparing its repr to the enum never matches
        kind = pp.kind
        t = f"{p}: {pp.annotation}" if pp.annotation is not inspect.Parameter.empty else p
        if pp.default is not inspect.Parameter.empty:
            t = f"{t} = {pp.default!r}"
        if kind == inspect.Parameter.VAR_POSITIONAL:
            t = f"*{t}"
        elif kind == inspect.Parameter.VAR_KEYWORD:
            t = f"**{t}"
        text.append(f" {t.ljust(30)}|{kind.name}")
    text.append(
        f") -> {sig.return_annotation}"
        if sig.return_annotation is not inspect.Signature.empty
        else ")"
    )
    return "\n".join(text)
def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
    """
    Displays the signature of a function, keeping only the arguments whose
    value differs from their default value.

    If *f* is an instance (neither a function nor a class) and *kwargs* is None,
    the signature of ``f.__init__`` is used and the values are read from
    ``f.__dict__``, so the output is ``ClassName(arg=value, ...)``.
    Otherwise the signature of *f* itself is used with the values in *kwargs*.
    Arguments without a default value are displayed only if a value is known.

    :param f: function, class or instance
    :param kwargs: values of the arguments, None means none is given
    :return: a string like ``name(arg1=value1, arg2=value2)``
    """
    # Every object has an ``__init__`` attribute, functions and classes included,
    # so it cannot be used to detect an instance.
    if kwargs is None and not (inspect.isroutine(f) or inspect.isclass(f)):
        fct = f.__init__
        kwargs = f.__dict__
        name = f.__class__.__name__
    else:
        fct = f
        name = f.__name__
    if kwargs is None:
        kwargs = {}
    rows = []
    sig = inspect.signature(fct)
    for p, pp in sig.parameters.items():
        d = pp.default
        if d is inspect.Parameter.empty:
            if p in kwargs:
                rows.append(f"{p}={kwargs[p]!r}")
            continue
        v = kwargs.get(p, d)
        if d != v:
            rows.append(f"{p}={v!r}")
    atts = ", ".join(rows)
    return f"{name}({atts})"
def pretty_onnx(onx: Union[FunctionProto, GraphProto, ModelProto, str]) -> str:
    """
    Displays an onnx proto in a better way.

    :param onx: a proto, or the path (str or path-like) to an onnx model
        loaded without its external data
    :return: text produced by ``onnx_array_api`` if it is installed
        (a FunctionProto is preceded by ``function: name[domain]``),
        by :func:`onnx.printer.to_text` otherwise
    """
    import os

    assert onx is not None, "onx cannot be None"
    if isinstance(onx, (str, os.PathLike)):
        onx = onnx_load(onx, load_external_data=False)
    assert onx is not None, "onx cannot be None"
    # Only the import is guarded: an ImportError raised while rendering
    # must not be silently turned into the fallback.
    try:
        from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
    except ImportError:
        from onnx.printer import to_text

        return to_text(onx)
    if isinstance(onx, FunctionProto):
        return (
            f"function: {onx.name}[{onx.domain}]\n"
            f"{onnx_simple_text_plot(onx, recursive=True)}"
        )
    return onnx_simple_text_plot(onx, recursive=True)
def make_hash(obj: Any) -> str:
    """
    Returns a simple hash of ``id(obj)`` in three uppercase letters.

    The identifier is reduced modulo ``26**3`` and written in base 26
    with the letters ``A`` to ``Z``, so distinct objects may share a hash.

    :param obj: any object
    :return: a string of three letters
    """
    value = id(obj) % (26**3)
    high, rest = divmod(value, 26**2)
    mid, low = divmod(rest, 26)
    return f"{chr(65 + high)}{chr(65 + mid)}{chr(65 + low)}"
def get_onnx_signature(model: ModelProto) -> Tuple[Tuple[str, Any], ...]:
    """
    Produces a tuple of tuples corresponding to the signature of a model,
    one element per graph input:

    * a tensor input becomes ``(name, elem_type, shape)``,
    * a sequence input becomes ``(name, [(name, elem_type, shape)])`` where
      elem_type and shape describe the tensors held by the sequence.

    Every dimension is its ``dim_param`` if set, its ``dim_value`` otherwise.

    :param model: model
    :return: signature
    :raises AssertionError: for any other kind of input
    """

    def _shape(tensor_type) -> tuple:
        return tuple(d.dim_param or d.dim_value for d in tensor_type.shape.dim)

    signature = []
    for i in model.graph.input:
        dt = i.type
        if dt.HasField("sequence_type"):
            seq_tensor = dt.sequence_type.elem_type.tensor_type
            signature.append((i.name, [(i.name, seq_tensor.elem_type, _shape(seq_tensor))]))
            continue
        if not dt.HasField("tensor_type"):
            raise AssertionError(f"Unable to interpret dt={dt!r} in {i!r}")
        signature.append((i.name, dt.tensor_type.elem_type, _shape(dt.tensor_type)))
    return tuple(signature)