Source code for experimental_experiment.helpers

import inspect
from typing import Any, Callable, Dict, Optional, Union
import numpy as np
from onnx import FunctionProto, GraphProto, ModelProto, load as onnx_load
from onnx.helper import np_dtype_to_tensor_dtype


def string_type(obj: Any, with_shape: bool = False, with_min_max: bool = False) -> str:
    """
    Displays the types of an object as a string.

    :param obj: any
    :param with_shape: displays shapes as well
    :param with_min_max: displays information about the values
    :return: str

    .. runpython::
        :showcode:

        from experimental_experiment.helpers import string_type

        print(string_type((1, ["r", 6.6])))
    """
    if obj is None:
        return "None"
    if isinstance(obj, tuple):
        if len(obj) == 1:
            s = string_type(obj[0], with_shape=with_shape, with_min_max=with_min_max)
            return f"({s},)"
        js = ",".join(
            string_type(o, with_shape=with_shape, with_min_max=with_min_max) for o in obj
        )
        return f"({js})"
    if isinstance(obj, list):
        js = ",".join(
            string_type(o, with_shape=with_shape, with_min_max=with_min_max) for o in obj
        )
        return f"#{len(obj)}[{js}]"
    if isinstance(obj, set):
        js = ",".join(
            string_type(o, with_shape=with_shape, with_min_max=with_min_max) for o in obj
        )
        return f"{{{js}}}"
    if isinstance(obj, dict):
        s = ",".join(
            f"{kv[0]}:{string_type(kv[1], with_shape=with_shape, with_min_max=with_min_max)}"
            for kv in obj.items()
        )
        return f"dict({s})"
    if isinstance(obj, np.ndarray):
        if with_min_max:
            s = string_type(obj, with_shape=with_shape)
            return f"{s}[{obj.min()}:{obj.max()}]"
        i = np_dtype_to_tensor_dtype(obj.dtype)
        if not with_shape:
            return f"A{i}r{len(obj.shape)}"
        return f"A{i}s{'x'.join(map(str, obj.shape))}"

    import torch  # delayed import so that torch is only required for torch objects

    if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
        return "DerivedDim"
    if isinstance(obj, torch.export.dynamic_shapes._Dim):
        return "Dim"
    if isinstance(obj, torch.SymInt):
        return "SymInt"
    if isinstance(obj, torch.SymFloat):
        return "SymFloat"
    if isinstance(obj, torch.Tensor):
        if with_min_max:
            s = string_type(obj, with_shape=with_shape)
            return f"{s}[{obj.min()}:{obj.max()}]"
        from .xbuilder._dtype_helper import torch_dtype_to_onnx_dtype

        i = torch_dtype_to_onnx_dtype(obj.dtype)
        if not with_shape:
            return f"T{i}r{len(obj.shape)}"
        return f"T{i}s{'x'.join(map(str, obj.shape))}"
    if isinstance(obj, int):
        if with_min_max:
            return f"int[{obj}]"
        return "int"
    if isinstance(obj, float):
        if with_min_max:
            return f"float[{obj}]"
        return "float"
    if isinstance(obj, str):
        return "str"
    if type(obj).__name__ == "MambaCache":
        return "MambaCache"
    if type(obj).__name__ == "Node" and hasattr(obj, "meta"):
        # torch.fx.node.Node
        return f"%{obj.target}"
    if type(obj).__name__ == "ValueInfoProto":
        return f"OT{obj.type.tensor_type.elem_type}"
    if obj.__class__.__name__ == "DynamicCache":
        kc = string_type(obj.key_cache, with_shape=with_shape, with_min_max=with_min_max)
        vc = string_type(obj.value_cache, with_shape=with_shape, with_min_max=with_min_max)
        return f"DynamicCache(key_cache={kc}, value_cache={vc})"
    if obj.__class__.__name__ == "BatchFeature":
        s = string_type(obj.data, with_shape=with_shape, with_min_max=with_min_max)
        return f"BatchFeature(data={s})"
    raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
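
# Illustrative usage of ``string_type`` (a sketch; the expected outputs in the
# trailing comments follow from the branches above, with ONNX dtype 1 = FLOAT):
#
#     import numpy as np
#
#     print(string_type((1, ["r", 6.6])))       # (int,#2[str,float])
#     a = np.zeros((2, 3), dtype=np.float32)
#     print(string_type(a, with_shape=True))    # A1s2x3
#     print(string_type(a, with_min_max=True))  # A1r2[0.0:0.0]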


def string_signature(sig: Any) -> str:
    """
    Displays the signature of a function.
    """

    def _k(p, kind):
        # recovers the attribute name matching the parameter kind
        for name in dir(p):
            if getattr(p, name) == kind:
                return name
        return repr(kind)

    text = [" __call__ ("]
    for p in sig.parameters:
        pp = sig.parameters[p]
        kind = pp.kind  # keep the enum so the comparison below can succeed
        t = f"{p}: {pp.annotation}" if pp.annotation is not inspect._empty else p
        if pp.default is not inspect._empty:
            t = f"{t} = {pp.default!r}"
        if kind == pp.VAR_POSITIONAL:
            t = f"*{t}"
        le = (30 - len(t)) * " "
        text.append(f" {t}{le}|{_k(pp, kind)}")
    text.append(
        f") -> {sig.return_annotation}"
        if sig.return_annotation is not inspect._empty
        else ")"
    )
    return "\n".join(text)
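
# Illustrative usage of ``string_signature`` (a sketch relying only on the
# standard library; ``f`` is a hypothetical function):
#
#     import inspect
#
#     def f(x: int, y: float = 1.0, *args):
#         return x
#
#     print(string_signature(inspect.signature(f)))
#     # prints one padded line per parameter followed by its kind,
#     # e.g. ``*args`` is tagged VAR_POSITIONAL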


def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
    """
    Displays the signature of a function, keeping only the arguments
    whose given value differs from the default.
    """
    if hasattr(f, "__init__") and kwargs is None:
        fct = f.__init__
        kwargs = f.__dict__
        name = f.__class__.__name__
    else:
        fct = f
        name = f.__name__

    if kwargs is None:
        kwargs = {}
    rows = []
    sig = inspect.signature(fct)
    for p in sig.parameters:
        pp = sig.parameters[p]
        d = pp.default
        if d is inspect._empty:
            if p in kwargs:
                v = kwargs[p]
                rows.append(f"{p}={v!r}")
            continue
        v = kwargs.get(p, d)
        if d != v:
            rows.append(f"{p}={v!r}")
            continue
    atts = ", ".join(rows)
    return f"{name}({atts})"
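
# Illustrative usage of ``string_sig`` (a sketch; ``Model`` is a hypothetical
# class whose constructor arguments are stored as attributes, so the instance
# ``__dict__`` can be compared against the defaults of ``__init__``):
#
#     class Model:
#         def __init__(self, size=16, bias=True):
#             self.size, self.bias = size, bias
#
#     print(string_sig(Model(size=32)))  # Model(size=32), bias kept at default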


def pretty_onnx(onx: Union[FunctionProto, GraphProto, ModelProto, str]) -> str:
    """
    Displays an ONNX proto in a readable way.
    """
    if isinstance(onx, str):
        onx = onnx_load(onx, load_external_data=False)

    try:
        from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

        if isinstance(onx, FunctionProto):
            return (
                f"function: {onx.name}[{onx.domain}]\n"
                f"{onnx_simple_text_plot(onx, recursive=True)}"
            )
        return onnx_simple_text_plot(onx, recursive=True)
    except ImportError:
        from onnx.printer import to_text

        return to_text(onx)
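
# Illustrative usage of ``pretty_onnx`` (a sketch; ``model.onnx`` is a
# hypothetical path; the nicer rendering applies when the optional
# ``onnx-array-api`` package is installed, otherwise ``onnx.printer.to_text``
# is used as a fallback):
#
#     print(pretty_onnx("model.onnx"))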


def make_hash(obj: Any) -> str:
    """
    Returns a simple hash of ``id(obj)`` as three uppercase letters.
    """
    aa = id(obj) % (26**3)
    return f"{chr(65 + aa // 26**2)}{chr(65 + (aa // 26) % 26)}{chr(65 + aa % 26)}"
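
# Illustrative usage of ``make_hash`` (a sketch; the value depends on
# ``id(obj)``, so it is only stable within a single process run):
#
#     o = object()
#     print(make_hash(o))  # e.g. 'QXZ'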