Source code for experimental_experiment.helpers

import ast
import enum
import functools
import inspect
import sys
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
from onnx import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    TensorProto,
    ValueInfoProto,
    load as onnx_load,
)
from onnx.helper import (
    np_dtype_to_tensor_dtype as onnx_np_dtype_to_tensor_dtype,
    tensor_dtype_to_np_dtype as onnx_tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import from_array as onnx_from_array


def size_type(dtype: Any) -> int:
    """Returns the element size in bytes for an element type.

    :param dtype: a ``TensorProto`` element type (int), a numpy dtype
        or a torch dtype
    :return: size in bytes of one element
    :raises AssertionError: if the type is unknown
    """
    if isinstance(dtype, int):
        # It is a TensorProto.DATATYPE
        if dtype in {
            TensorProto.DOUBLE,
            TensorProto.INT64,
            TensorProto.UINT64,
            TensorProto.COMPLEX64,
        }:
            return 8
        if dtype in {TensorProto.FLOAT, TensorProto.INT32, TensorProto.UINT32}:
            return 4
        if dtype in {
            TensorProto.FLOAT16,
            TensorProto.BFLOAT16,
            TensorProto.INT16,
            TensorProto.UINT16,
        }:
            return 2
        if dtype in {TensorProto.INT8, TensorProto.UINT8, TensorProto.BOOL}:
            return 1
        if dtype in {TensorProto.COMPLEX128}:
            return 16
        raise AssertionError(f"Unable to return the element size for type {dtype}")
    if dtype == np.float64 or dtype == np.int64:
        return 8
    # fix: the second operand used to be ``np.float32`` twice, so ``np.int32``
    # was never recognized and fell through to the final AssertionError
    if dtype == np.float32 or dtype == np.int32:
        return 4
    if dtype == np.float16 or dtype == np.int16:
        return 2
    if dtype == np.int8 or dtype == np.uint8:
        return 1
    if hasattr(np, "uint64"):
        # it fails on mac
        if dtype == np.uint64:
            return 8
        if dtype == np.uint32:
            return 4
        if dtype == np.uint16:
            return 2
    # torch is imported lazily so that numpy-only callers do not need it
    import torch

    if dtype in {torch.float64, torch.int64}:
        return 8
    if dtype in {torch.float32, torch.int32}:
        return 4
    if dtype in {torch.float16, torch.int16, torch.bfloat16}:
        return 2
    if dtype in {torch.int8, torch.uint8, torch.bool}:
        return 1
    if hasattr(torch, "uint64"):
        # it fails on mac
        if dtype in {torch.uint64}:
            return 8
        if dtype in {torch.uint32}:
            return 4
        if dtype in {torch.uint16}:
            return 2
    raise AssertionError(f"Unexpected dtype={dtype}")
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """
    Converts a TensorProto's data_type to corresponding numpy dtype.
    It can be used while making tensor.

    :param tensor_dtype: TensorProto's data_type
    :return: numpy's data_type
    """
    if tensor_dtype < 16:
        # standard numpy types, onnx handles the conversion directly
        return onnx_tensor_dtype_to_np_dtype(tensor_dtype)
    # extended types (bfloat16, float8 variants) require ml_dtypes
    try:
        import ml_dtypes  # noqa: F401
    except ImportError as e:
        raise ValueError(
            f"Unsupported value for tensor_dtype, "
            f"numpy does not support onnx type {tensor_dtype}. "
            f"ml_dtypes can be used."
        ) from e
    lookup = {
        TensorProto.BFLOAT16: ml_dtypes.bfloat16,
        TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn,
        TensorProto.FLOAT8E4M3FNUZ: ml_dtypes.float8_e4m3fnuz,
        TensorProto.FLOAT8E5M2: ml_dtypes.float8_e5m2,
        TensorProto.FLOAT8E5M2FNUZ: ml_dtypes.float8_e5m2fnuz,
    }
    assert (
        tensor_dtype in lookup
    ), f"Unable to find tensor_dtype={tensor_dtype!r} in mapping={lookup}"
    return lookup[tensor_dtype]
def string_type(
    obj: Any,
    with_shape: bool = False,
    with_min_max: bool = False,
    with_device: bool = False,
    ignore: bool = False,
    limit: int = 10,
) -> str:
    """
    Displays the types of an object as a string.

    :param obj: any
    :param with_shape: displays shapes as well
    :param with_min_max: displays information about the values
    :param with_device: display the device
    :param ignore: if True, just prints the type for unknown types
    :param limit: a tuple with fewer elements than this limit is displayed
        element by element
    :return: str

    .. runpython::
        :showcode:

        from experimental_experiment.helpers import string_type

        print(string_type((1, ["r", 6.6])))
    """
    if obj is None:
        return "None"

    # --- tuples ---
    if isinstance(obj, tuple):
        if len(obj) == 1:
            s = string_type(
                obj[0],
                with_shape=with_shape,
                with_min_max=with_min_max,
                with_device=with_device,
                ignore=ignore,
                limit=limit,
            )
            return f"({s},)"
        if len(obj) < limit:
            # short tuple: render every element
            js = ",".join(
                string_type(
                    o,
                    with_shape=with_shape,
                    with_min_max=with_min_max,
                    with_device=with_device,
                    ignore=ignore,
                    limit=limit,
                )
                for o in obj
            )
            return f"({js})"
        # long tuple: render the first element only, then the count
        tt = string_type(
            obj[0],
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            ignore=ignore,
            limit=limit,
        )
        if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
            mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
            # NOTE(review): this uses ``A[{avg}]`` while the list/set branches
            # use ``A{avg}`` or ``:{avg}`` — possibly an inconsistency, confirm
            return f"({tt},...)#{len(obj)}[{mini},{maxi}:A[{avg}]]"
        return f"({tt},...)#{len(obj)}" if with_shape else f"({tt},...)"

    # --- lists ---
    if isinstance(obj, list):
        # NOTE(review): the list branch uses a hard-coded 10 instead of
        # ``limit`` (unlike the tuple branch) — confirm this is intended
        if len(obj) < 10:
            js = ",".join(
                string_type(
                    o,
                    with_shape=with_shape,
                    with_min_max=with_min_max,
                    with_device=with_device,
                    ignore=ignore,
                    limit=limit,
                )
                for o in obj
            )
            return f"#{len(obj)}[{js}]"
        tt = string_type(
            obj[0],
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            ignore=ignore,
            limit=limit,
        )
        if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
            mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
            return f"[{tt},...]#{len(obj)}[{mini},{maxi}:{avg}]"
        return f"[{tt},...]#{len(obj)}" if with_shape else f"[{tt},...]"

    # --- sets ---
    if isinstance(obj, set):
        if len(obj) < 10:
            js = ",".join(
                string_type(
                    o,
                    with_shape=with_shape,
                    with_min_max=with_min_max,
                    with_device=with_device,
                    ignore=ignore,
                    limit=limit,
                )
                for o in obj
            )
            return f"{{{js}}}"
        if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
            mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
            return f"{{...}}#{len(obj)}[{mini},{maxi}:A{avg}]"
        return f"{{...}}#{len(obj)}" if with_shape else "{...}"

    # --- dictionaries ---
    if isinstance(obj, dict):
        if len(obj) == 0:
            return "{}"
        kws = dict(
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            ignore=ignore,
            limit=limit,
        )
        s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items())
        return f"dict({s})"

    # --- numpy arrays: rendered as A<onnx dtype>r<rank> or A<onnx dtype>s<shape> ---
    if isinstance(obj, np.ndarray):
        if with_min_max:
            s = string_type(obj, with_shape=with_shape)
            if len(obj.shape) == 0:
                # scalar array: print the value itself
                return f"{s}={obj}"
            if obj.size == 0:
                return f"{s}[empty]"
            n_nan = np.isnan(obj.reshape((-1,))).astype(int).sum()
            if n_nan > 0:
                # statistics are computed on the non-nan values only
                nob = obj.ravel()
                nob = nob[~np.isnan(nob)]
                if nob.size == 0:
                    return f"{s}[N{n_nan}nans]"
                return f"{s}[{nob.min()},{nob.max()}:A{nob.astype(float).mean()}N{n_nan}nans]"
            return f"{s}[{obj.min()},{obj.max()}:A{obj.astype(float).mean()}]"
        i = np_dtype_to_tensor_dtype(obj.dtype)
        if not with_shape:
            return f"A{i}r{len(obj.shape)}"
        return f"A{i}s{'x'.join(map(str, obj.shape))}"

    # torch is imported lazily: everything below involves torch types
    import torch

    if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
        return "DerivedDim"
    if isinstance(obj, torch.export.dynamic_shapes._Dim):
        return "Dim"
    if isinstance(obj, torch.SymInt):
        return "SymInt"
    if isinstance(obj, torch.SymFloat):
        return "SymFloat"
    if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
        # FakeTensor uses prefix "F", real tensors use "T"
        i = torch_dtype_to_onnx_dtype(obj.dtype)
        prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
        if not with_shape:
            return f"{prefix}F{i}r{len(obj.shape)}"
        return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
    if isinstance(obj, torch.Tensor):
        if with_min_max:
            s = string_type(obj, with_shape=with_shape, with_device=with_device)
            if len(obj.shape) == 0:
                return f"{s}={obj}"
            if obj.numel() == 0:
                return f"{s}[empty]"
            n_nan = obj.reshape((-1,)).isnan().to(int).sum()
            if n_nan > 0:
                nob = obj.reshape((-1,))
                nob = nob[~nob.isnan()]
                if obj.dtype in {torch.complex64, torch.complex128}:
                    # NOTE(review): the ``:`` sits inside the format braces, so
                    # ``A{nob.mean()}N{n_nan}nans`` is parsed as a format spec for
                    # ``nob.abs().max()`` — looks like a misplaced brace, confirm
                    return (
                        f"{s}[{nob.abs().min()},{nob.abs().max():A{nob.mean()}N{n_nan}nans}]"
                    )
                return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}N{n_nan}nans]"
            if obj.dtype in {torch.complex64, torch.complex128}:
                # complex tensors: statistics on the modulus
                return f"{s}[{obj.abs().min()},{obj.abs().max()}:A{obj.abs().mean()}]"
            return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}]"
        i = torch_dtype_to_onnx_dtype(obj.dtype)
        prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
        if not with_shape:
            return f"{prefix}T{i}r{len(obj.shape)}"
        return f"{prefix}T{i}s{'x'.join(map(str, obj.shape))}"

    # --- scalars (bool must be tested before int: bool is a subclass of int) ---
    if isinstance(obj, bool):
        if with_min_max:
            return f"bool={obj}"
        return "bool"
    if isinstance(obj, int):
        if with_min_max:
            return f"int={obj}"
        return "int"
    if isinstance(obj, float):
        if with_min_max:
            return f"float={obj}"
        return "float"
    if isinstance(obj, str):
        return "str"
    if isinstance(obj, slice):
        return "slice"

    # others classes
    if type(obj).__name__ == "MambaCache":
        # NOTE(review): ``ignore`` is not forwarded to the recursive calls here
        c = string_type(
            obj.conv_states,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        d = string_type(
            obj.ssm_states,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"MambaCache(conv_states={c}, ssm_states={d})"
    if type(obj).__name__ == "Node" and hasattr(obj, "meta"):
        # torch.fx.node.Node
        return f"%{obj.target}"
    if type(obj).__name__ == "ValueInfoProto":
        return f"OT{obj.type.tensor_type.elem_type}"
    if obj.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
        kc = string_type(
            obj.key_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        vc = string_type(
            obj.value_cache,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
    if obj.__class__.__name__ == "BatchFeature":
        s = string_type(
            obj.data,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"BatchFeature(data={s})"
    if obj.__class__.__name__ == "BatchEncoding":
        s = string_type(
            obj.data,
            with_shape=with_shape,
            with_min_max=with_min_max,
            with_device=with_device,
            limit=limit,
        )
        return f"BatchEncoding(data={s})"
    if obj.__class__.__name__ == "VirtualTensor":
        return (
            f"{obj.__class__.__name__}(name={obj.name!r}, "
            f"dtype={obj.dtype}, shape={obj.shape})"
        )
    if obj.__class__.__name__ == "_DimHint":
        return str(obj)
    if isinstance(obj, torch.nn.Module):
        return f"{obj.__class__.__name__}(...)"
    if isinstance(obj, torch.dtype):
        return f"{obj.__class__.__name__}({obj})"
    if isinstance(obj, torch.utils._pytree.TreeSpec):
        return repr(obj).replace(" ", "").replace("\n", " ")
    if ignore:
        # unknown type but the caller asked not to fail
        return f"{obj.__class__.__name__}(...)"
    raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
def string_signature(sig: Any) -> str:
    """
    Displays the signature of a functions.

    :param sig: signature, as returned by :func:`inspect.signature`
    :return: multi-line string, one line per parameter
    """

    def _k(p, kind):
        # Returns the name of the attribute of *p* holding the value *kind*
        # (e.g. ``'POSITIONAL_OR_KEYWORD'``), used to display the kind.
        for name in dir(p):
            if getattr(p, name) == kind:
                return name
        return repr(kind)

    text = [" __call__ ("]
    for p in sig.parameters:
        pp = sig.parameters[p]
        # fix: ``kind`` used to be ``repr(pp.kind)`` (a string), so the
        # comparison with ``pp.VAR_POSITIONAL`` below was always False and
        # ``*args`` parameters were never prefixed with ``*``, and ``_k``
        # could never find a matching attribute
        kind = pp.kind
        t = f"{p}: {pp.annotation}" if pp.annotation is not inspect._empty else p
        if pp.default is not inspect._empty:
            t = f"{t} = {pp.default!r}"
        if kind == pp.VAR_POSITIONAL:
            t = f"*{t}"
        le = (30 - len(t)) * " "
        text.append(f" {t}{le}|{_k(pp,kind)}")
    text.append(
        f") -> {sig.return_annotation}"
        if sig.return_annotation is not inspect._empty
        else ")"
    )
    return "\n".join(text)
def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
    """
    Displays the signature of a functions if the default
    if the given value is different from
    """
    # When no kwargs is given and the object exposes __init__, assume it is a
    # class instance: use its constructor signature and its attributes.
    if hasattr(f, "__init__") and kwargs is None:
        fct = f.__init__
        kwargs = f.__dict__
        name = f.__class__.__name__
    else:
        fct = f
        name = f.__name__
    kwargs = kwargs or {}

    def _fmt(pname, value):
        # IntEnum values are displayed by name, everything else with repr.
        return f"{pname}={value.name}" if isinstance(value, enum.IntEnum) else f"{pname}={value!r}"

    rows = []
    for pname, param in inspect.signature(fct).parameters.items():
        default = param.default
        if default is inspect._empty:
            # no default: only display the parameter if a value was supplied
            if pname in kwargs:
                rows.append(_fmt(pname, kwargs[pname]))
            continue
        value = kwargs.get(pname, default)
        if default != value:
            # only display values differing from the default
            rows.append(_fmt(pname, value))
    return f"{name}({', '.join(rows)})"
@functools.cache
def onnx_dtype_name(itype: int) -> str:
    """Returns the ONNX name for a specific element type."""
    # Scan TensorProto attributes for one whose value matches itype.
    found = next(
        (att for att in dir(TensorProto) if getattr(TensorProto, att) == itype),
        None,
    )
    if found is not None:
        return found
    raise ValueError(f"Unexpected value itype: {itype}")
def pretty_onnx(
    onx: Union[FunctionProto, GraphProto, ModelProto, ValueInfoProto, str],
    with_attributes: bool = False,
    highlight: Optional[Set[str]] = None,
) -> str:
    """
    Displays an onnx prot in a better way.

    :param onx: a proto (model, graph, function, node, attribute, tensor,
        value info) or a path to a model file
    :param with_attributes: displays attributes as well, if only a node is printed
    :param highlight: to highlight some names
    :return: text
    """
    assert onx is not None, "onx cannot be None"
    if isinstance(onx, str):
        # a path: load the model, external data is left on disk
        onx = onnx_load(onx, load_external_data=False)
    assert onx is not None, "onx cannot be None"

    # ValueInfoProto: "<DTYPE>[d1,d2,...] name"
    if isinstance(onx, ValueInfoProto):
        name = onx.name
        itype = onx.type.tensor_type.elem_type
        # dim_param (symbolic) wins over dim_value (static)
        shape = tuple((d.dim_param or d.dim_value) for d in onx.type.tensor_type.shape.dim)
        shape_str = ",".join(map(str, shape))
        return f"{onnx_dtype_name(itype)}[{shape_str}] {name}"

    # AttributeProto: "name=value", rendering depends on the attribute type
    if isinstance(onx, AttributeProto):
        att = onx
        if att.type == AttributeProto.INT:
            return f"{att.name}={att.i}"
        if att.type == AttributeProto.INTS:
            return f"{att.name}={att.ints}"
        if att.type == AttributeProto.FLOAT:
            return f"{att.name}={att.f}"
        if att.type == AttributeProto.FLOATS:
            return f"{att.name}={att.floats}"
        if att.type == AttributeProto.STRING:
            return f"{att.name}={att.s!r}"
        if att.type == AttributeProto.TENSOR:
            from .reference import to_array_extended

            v = to_array_extended(att.t)
            vf = v.reshape((-1,))
            # only the first 10 values are displayed
            if vf.size < 10:
                tt = f"[{', '.join(map(str, vf))}]"
            else:
                tt = f"[{', '.join(map(str, vf[:10]))}, ...]"
            if len(v.shape) != 1:
                return f"{att.name}=tensor({tt}, dtype={v.dtype}).reshape({v.shape})"
            return f"{att.name}=tensor({tt}, dtype={v.dtype})"
        raise NotImplementedError(
            f"pretty_onnx not implemented yet for AttributeProto={att!r}"
        )

    # NodeProto: "domain.op_type(inputs) -> outputs", optionally with attributes
    if isinstance(onx, NodeProto):

        def _high(n):
            # wraps highlighted names with ** **
            if highlight and n in highlight:
                return f"**{n}**"
            return n

        text = (
            f"{onx.op_type}({', '.join(map(_high, onx.input))})"
            f" -> {', '.join(map(_high, onx.output))}"
        )
        if onx.domain:
            text = f"{onx.domain}.{text}"
        if not with_attributes or not onx.attribute:
            return text
        rows = []
        for att in onx.attribute:
            rows.append(pretty_onnx(att))
        if len(rows) > 1:
            # several attributes: one per line, indented
            suffix = "\n".join(f" {s}" for s in rows)
            return f"{text}\n{suffix}"
        return f"{text} --- {rows[0]}"

    # TensorProto: type, shape and name only, no values
    if isinstance(onx, TensorProto):
        shape = "x".join(map(str, onx.dims))
        return f"TensorProto:{onx.data_type}:{shape}:{onx.name}"

    # model/graph/function: delegate to onnx_array_api when available,
    # otherwise fall back to onnx's own text printer
    try:
        from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

        if isinstance(onx, FunctionProto):
            return (
                f"function: {onx.name}[{onx.domain}]\n"
                f"{onnx_simple_text_plot(onx, recursive=True)}"
            )
        return onnx_simple_text_plot(onx, recursive=True)
    except ImportError:
        from onnx.printer import to_text

        return to_text(onx)
def make_hash(obj: Any) -> str:
    """
    Returns a simple hash of ``id(obj)`` in three letters.

    The hash is not unique: the id is folded modulo ``26**3``
    and encoded with three uppercase letters A-Z.
    (The docstring used to say "four letter" although the function
    has always produced three characters.)
    """
    aa = id(obj) % (26**3)
    return f"{chr(65 + aa // 26 ** 2)}{chr(65 + (aa // 26) % 26)}{chr(65 + aa % 26)}"
def get_onnx_signature(model: ModelProto) -> Tuple[Tuple[str, Any], ...]:
    """
    Produces a tuple of tuples corresponding to the signatures.

    :param model: model
    :return: signature
    """
    sig = []
    for inp in model.graph.input:
        dt = inp.type
        if dt.HasField("sequence_type"):
            # sequence of tensors: the element description is nested once more
            tensor = dt.sequence_type.elem_type.tensor_type
            dims = tuple(d.dim_param or d.dim_value for d in tensor.shape.dim)
            sig.append((inp.name, [(inp.name, tensor.elem_type, dims)]))
        elif dt.HasField("tensor_type"):
            # plain tensor input: (name, element type, shape)
            dims = tuple(
                d.dim_param or d.dim_value for d in dt.tensor_type.shape.dim
            )
            sig.append((inp.name, dt.tensor_type.elem_type, dims))
        else:
            raise AssertionError(f"Unable to interpret dt={dt!r} in {inp!r}")
    return tuple(sig)
def convert_endian(tensor: TensorProto) -> None:
    """Call to convert endianness of raw data in tensor.

    Args:
        tensor: TensorProto to be converted.
    """
    # Reinterpret the raw bytes with the matching numpy dtype, swap the byte
    # order of every element, then store the swapped bytes back in place.
    np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
    swapped = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap()
    tensor.raw_data = swapped.tobytes()
def from_array_ml_dtypes(arr: np.ndarray, name: Optional[str] = None) -> TensorProto:
    """
    Converts a numpy array to a tensor def assuming the dtype
    is defined in ml_dtypes.

    Args:
        arr: a numpy array.
        name: (optional) the name of the tensor.

    Returns:
        TensorProto: the converted tensor def.
    """
    import ml_dtypes

    assert isinstance(arr, np.ndarray), f"arr must be of type np.ndarray, got {type(arr)}"

    tensor = TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name

    # ml_dtypes dtype -> onnx element type
    conversions = (
        (ml_dtypes.bfloat16, TensorProto.BFLOAT16),
        (ml_dtypes.float8_e4m3fn, TensorProto.FLOAT8E4M3FN),
        (ml_dtypes.float8_e4m3fnuz, TensorProto.FLOAT8E4M3FNUZ),
        (ml_dtypes.float8_e5m2, TensorProto.FLOAT8E5M2),
        (ml_dtypes.float8_e5m2fnuz, TensorProto.FLOAT8E5M2FNUZ),
    )
    for ml_dtype, onnx_dtype in conversions:
        if arr.dtype == ml_dtype:
            tensor.data_type = onnx_dtype
            break
    else:
        raise NotImplementedError(f"No conversion from {arr.dtype}")

    tensor.raw_data = arr.tobytes()  # note: tobytes() is only after 1.9.
    if sys.byteorder == "big":
        # raw_data is expected to be little endian
        convert_endian(tensor)
    return tensor
def from_array_extended(tensor: np.ndarray, name: Optional[str] = None) -> TensorProto:
    """
    Converts an array into a TensorProto.

    :param tensor: numpy array
    :param name: name
    :return: TensorProto
    """
    # custom dtypes registered by onnx's reference implementation
    from onnx.reference.ops.op_cast import (
        bfloat16,
        float8e4m3fn,
        float8e4m3fnuz,
        float8e5m2,
        float8e5m2fnuz,
    )

    dt = tensor.dtype
    # The ``descr`` check disambiguates the custom dtypes, which may otherwise
    # compare equal to their underlying storage dtype (uint8/uint16).
    if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
        to = TensorProto.FLOAT8E4M3FN
        dt_to = np.uint8
    elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
        to = TensorProto.FLOAT8E4M3FNUZ
        dt_to = np.uint8
    elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
        to = TensorProto.FLOAT8E5M2
        dt_to = np.uint8
    elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
        to = TensorProto.FLOAT8E5M2FNUZ
        dt_to = np.uint8
    elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
        to = TensorProto.BFLOAT16
        dt_to = np.uint16
    else:
        # not one of onnx's custom dtypes: try ml_dtypes, then plain numpy
        try:
            import ml_dtypes
        except ImportError:
            ml_dtypes = None
        if ml_dtypes is not None and (
            tensor.dtype == ml_dtypes.bfloat16
            or tensor.dtype == ml_dtypes.float8_e4m3fn
            or tensor.dtype == ml_dtypes.float8_e4m3fnuz
            or tensor.dtype == ml_dtypes.float8_e5m2
            or tensor.dtype == ml_dtypes.float8_e5m2fnuz
        ):
            return from_array_ml_dtypes(tensor, name)
        return onnx_from_array(tensor, name)

    # serialize with the storage dtype, then overwrite the element type
    t = onnx_from_array(tensor.astype(dt_to), name)
    t.data_type = to
    return t
def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
    """
    Converts an onnx type into a torch dtype.

    :param itype: onnx dtype
    :return: torch dtype
    :raises NotImplementedError: if the type has no torch equivalent
    """
    import torch

    if itype == TensorProto.FLOAT:
        return torch.float32
    if itype == TensorProto.FLOAT16:
        return torch.float16
    if itype == TensorProto.BFLOAT16:
        return torch.bfloat16
    if itype == TensorProto.DOUBLE:
        return torch.float64
    if itype == TensorProto.INT32:
        return torch.int32
    if itype == TensorProto.INT64:
        return torch.int64
    if itype == TensorProto.UINT32:
        return torch.uint32
    if itype == TensorProto.UINT64:
        return torch.uint64
    if itype == TensorProto.BOOL:
        return torch.bool
    if itype == TensorProto.INT16:
        return torch.int16
    if itype == TensorProto.UINT16:
        return torch.uint16
    # fix: INT8/UINT8 used to map to torch.int16/torch.uint16 (copy-paste bug)
    if itype == TensorProto.INT8:
        return torch.int8
    if itype == TensorProto.UINT8:
        return torch.uint8
    if itype == TensorProto.COMPLEX64:
        return torch.complex64
    if itype == TensorProto.COMPLEX128:
        return torch.complex128
    raise NotImplementedError(f"Unable to convert onnx type {itype} to torch.type.")
def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
    """
    Converts a torch dtype into a onnx element type.

    :param to: torch dtype
    :return: onnx type
    :raises NotImplementedError: if the dtype has no onnx equivalent
    """
    import torch

    # candidates tested in the same order as the original if/elif chain
    pairs = (
        (torch.float32, TensorProto.FLOAT),
        (torch.float16, TensorProto.FLOAT16),
        (torch.bfloat16, TensorProto.BFLOAT16),
        (torch.float64, TensorProto.DOUBLE),
        (torch.int64, TensorProto.INT64),
        (torch.int32, TensorProto.INT32),
        (torch.bool, TensorProto.BOOL),
        # symbolic scalar types are mapped to their concrete equivalent
        (torch.SymInt, TensorProto.INT64),
        (torch.SymFloat, TensorProto.FLOAT),
        (torch.complex64, TensorProto.COMPLEX64),
        (torch.complex128, TensorProto.COMPLEX128),
    )
    for torch_dtype, onnx_dtype in pairs:
        if to == torch_dtype:
            return onnx_dtype
    raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
def dtype_to_tensor_dtype(dt: "dtype") -> int:  # noqa: F821
    """
    Converts a torch dtype or numpy dtype into a onnx element type.

    :param dt: numpy or torch dtype
    :return: onnx type
    """
    # Try the numpy conversion first; if the dtype is not a numpy one,
    # fall back to the torch conversion.
    try:
        return np_dtype_to_tensor_dtype(dt)
    except (KeyError, TypeError, ValueError):
        pass
    return torch_dtype_to_onnx_dtype(dt)
def np_dtype_to_tensor_dtype(dt: "dtype") -> int:  # noqa: F821
    """
    Converts a numpy dtype into a onnx element type.

    :param dt: numpy dtype
    :return: onnx type
    :raises ValueError: if the dtype cannot be converted
    """
    try:
        return onnx_np_dtype_to_tensor_dtype(dt)
    except ValueError:
        # onnx does not know this dtype: try the ml_dtypes extensions
        try:
            import ml_dtypes
        except ImportError:
            ml_dtypes = None
        if ml_dtypes is not None:
            extended = (
                (ml_dtypes.bfloat16, TensorProto.BFLOAT16),
                (ml_dtypes.float8_e4m3fn, TensorProto.FLOAT8E4M3FN),
                (ml_dtypes.float8_e4m3fnuz, TensorProto.FLOAT8E4M3FNUZ),
                (ml_dtypes.float8_e5m2, TensorProto.FLOAT8E5M2),
                (ml_dtypes.float8_e5m2fnuz, TensorProto.FLOAT8E5M2FNUZ),
            )
            for ml_dtype, onnx_dtype in extended:
                if dt == ml_dtype:
                    return onnx_dtype
        raise ValueError(f"Unable to convert type {dt}")
def rename_dynamic_dimensions(
    constraints: Dict[str, Set[str]], original: Set[str], ban_prefix: str = "DYN"
) -> Dict[str, str]:
    """
    Renames dynamic shapes as requested by the user. :func:`torch.export.export` uses
    many names for dynamic dimensions. When building the onnx model,
    some of them are redundant and can be replaced by the name provided by the user.

    :param constraints: exhaustive list of used name and all the values equal to it
    :param original: the names to use if possible
    :param ban_prefix: avoid any rewriting by a constant starting with this prefix
    :return: replacement dictionary
    """
    # the user-provided names map to themselves
    replacements = {dim: dim for dim in original}
    all_names = set(constraints) | original

    pending = set(constraints)
    # bounded number of passes to guarantee termination
    budget = len(replacements)
    while pending and budget > 0:
        budget -= 1
        for key, equal_names in constraints.items():
            shared = equal_names & original
            if not shared:
                continue
            # deterministic choice: smallest user name in sorted order
            chosen = sorted(shared)[0]
            if ban_prefix and chosen.startswith(ban_prefix):
                continue
            replacements[key] = chosen
            # propagate the choice to every name known to be equal
            for other in equal_names:
                if other not in replacements:
                    replacements[other] = chosen
        pending = all_names - set(replacements)
    return replacements
def rename_dynamic_expression(expression: str, replacements: Dict[str, str]):
    """
    Renames variables of an expression.

    :param expression: something like ``s15 + seq_length``
    :param replacements: replacements to make
    :return: new string
    """

    class _Renamer(ast.NodeTransformer):
        # Rewrites every Name node whose identifier appears in replacements.
        def visit_Name(self, node):
            node.id = replacements.get(node.id, node.id)
            return node

    return ast.unparse(_Renamer().visit(ast.parse(expression)))
def flatten_object(x: Any, drop_keys: bool = False) -> List[Any]:
    """
    Flattens the object. It accepts some common classes used in deep learning.

    :param x: any object
    :param drop_keys: drop the keys if a dictionary is flattened.
        Keeps the order defined by the dictionary if False,
        sort them if True.
    :return: flattened object
    """
    if x is None:
        return x
    if isinstance(x, (list, tuple)):
        res = []
        for i in x:
            # tensors (anything with a shape) and scalars are kept as-is
            if i is None or hasattr(i, "shape") or isinstance(i, (int, float, str)):
                res.append(i)
            else:
                res.extend(flatten_object(i, drop_keys=drop_keys))
        return tuple(res) if isinstance(x, tuple) else res
    if isinstance(x, dict):
        # We flatten the keys.
        if drop_keys:
            return flatten_object(list(x.values()), drop_keys=drop_keys)
        return flatten_object(list(x.items()), drop_keys=drop_keys)
    if x.__class__.__name__ == "DynamicCache":
        res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
        return tuple(res)
    if x.__class__.__name__ == "MambaCache":
        # fix: this used to be ``tuple(x.conv_states, x.ssm_states)`` which
        # raises TypeError (tuple accepts at most one argument)
        return (x.conv_states, x.ssm_states)
    if hasattr(x, "to_tuple"):
        return flatten_object(x.to_tuple(), drop_keys=drop_keys)
    if hasattr(x, "shape"):
        # A tensor. Nothing to do.
        return x
    raise TypeError(
        f"Unexpected type {type(x)} for x, drop_keys={drop_keys}, "
        f"content is {string_type(x, with_shape=True)}"
    )
[docs] def max_diff( expected: Any, got: Any, verbose: int = 0, level: int = 0, flatten: bool = False, debug_info: Optional[List[str]] = None, begin: int = 0, end: int = -1, _index: int = 0, allow_unique_tensor_with_list_of_one_element: bool = True, ) -> Dict[str, float]: """ Returns the maximum discrepancy. :param expected: expected values :param got: values :param verbose: verbosity level :param level: for embedded outputs, used for debug purpposes :param flatten: flatten outputs :param debug_info: debug information :param begin: first output to considered :param end: last output to considered (-1 for the last one) :param _index: used with begin and end :param allow_unique_tensor_with_list_of_one_element: allow a comparison between a single tensor and a list of one tensor :return: dictionary with many values * abs: max abolute error * rel: max relative error * sum: sum of the errors * n: number of outputs values, if there is one output, this number will be the number of elements of this output * dnan: difference in the number of nan You may use :func:`string_diff` to display the discrepancies in one string. """ if expected is None and got is None: return dict(abs=0, rel=0, sum=0, n=0, dnan=0) if allow_unique_tensor_with_list_of_one_element: if hasattr(expected, "shape") and isinstance(got, (list, tuple)) and len(got) == 1: return max_diff( expected, got[0], verbose=verbose, level=level, flatten=False, debug_info=debug_info, allow_unique_tensor_with_list_of_one_element=False, ) return max_diff( expected, got, verbose=verbose, level=level, flatten=flatten, debug_info=debug_info, begin=begin, end=end, _index=_index, allow_unique_tensor_with_list_of_one_element=False, ) if hasattr(expected, "to_tuple"): if verbose >= 6: print(f"[max_diff] to_tuple1: {string_type(expected)} ? 
{string_type(got)}") return max_diff( expected.to_tuple(), got, verbose=verbose, level=level + 1, debug_info=( [*(debug_info if debug_info else []), f"{' ' * level}to_tupleA"] if verbose > 5 else None ), begin=begin, end=end, _index=_index, flatten=flatten, ) if hasattr(got, "to_tuple"): if verbose >= 6: print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}") return max_diff( expected, got.to_tuple(), verbose=verbose, level=level + 1, debug_info=( [*(debug_info if debug_info else []), f"{' ' * level}to_tupleB"] if verbose > 5 else None ), begin=begin, end=end, _index=_index, flatten=flatten, ) if isinstance(got, (list, tuple)): if len(got) != 1: if verbose >= 6: print( f"[max_diff] list,tuple,2: {string_type(expected)} " f"? {string_type(got)}" ) if verbose > 2: import torch print( f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, " f"len(got)={len(got)}, level={level}, _index={_index}" ) for i, (a, b) in enumerate(zip(expected, got)): if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): print( f" i={i} expected {a.dtype}:{a.shape}, " f"has {b.dtype}:{b.shape}, _index={_index}" ) else: print( f" i={i} a is {type(a)}, " f"b is {type(b)}, _index={_index}" ) return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if verbose >= 6: print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}") return max_diff( expected, got[0], verbose=verbose, level=level + 1, begin=begin, end=end, _index=_index, debug_info=debug_info, flatten=flatten, ) if isinstance(expected, (tuple, list)): if verbose >= 6: print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}") if len(expected) == 1 and not isinstance(got, type(expected)): if verbose >= 6: print(f"[max_diff] list,tuple,3: {string_type(expected)} ? 
{string_type(got)}") return max_diff( expected[0], got, verbose=verbose, level=level + 1, begin=begin, end=end, _index=_index, debug_info=debug_info, flatten=flatten, ) if not isinstance(got, (tuple, list)): if verbose >= 6: print(f"[max_diff] list,tuple,4: {string_type(expected)} ? {string_type(got)}") if verbose > 2: print( f"[max_diff] inf because type(expected)={type(expected)}, " f"type(got)={type(got)}, level={level}, _index={_index}" ) return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if len(got) != len(expected): if flatten: if verbose >= 6: print( f"[max_diff] list,tuple,5: {string_type(expected)} " f"? {string_type(got)}" ) # Let's flatten. if verbose > 2: print( f"[max_diff] flattening because of length mismatch, " f"expected is {string_type(expected)} and got is {string_type(got)}" ) return max_diff( flatten_object(expected, drop_keys=True), flatten_object(got, drop_keys=True), verbose=verbose, level=level, begin=begin, end=end, _index=_index, debug_info=( [ *(debug_info if debug_info else []), ( f"{' ' * level}flatten[" f"{string_type(expected)},{string_type(got)}]" ), ] if verbose > 5 else None ), flatten=flatten, ) if verbose > 2: import torch print( f"[max_diff] (b) inf because len(expected)={len(expected)}, " f"len(got)={len(got)}, level={level}, _index={_index}" ) for i, (a, b) in enumerate(zip(expected, got)): if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): print( f" i={i} expected {a.dtype}:{a.shape}, " f"has {b.dtype}:{b.shape}, _index={_index}" ) else: print(f" i={i} a is {type(a)}, b is {type(b)}") return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if verbose >= 6: print(f"[max_diff] list,tuple,6: {string_type(expected)} ? 
{string_type(got)}") am, rm, sm, n, dn = 0, 0, 0.0, 0.0, 0 for ip, (e, g) in enumerate(zip(expected, got)): d = max_diff( e, g, verbose=verbose, level=level + 1, debug_info=( [ *(debug_info if debug_info else []), f"{' ' * level}[{ip}] so far abs {am} - rel {rm}", ] if verbose > 5 else None ), begin=begin, end=end, _index=_index + ip, flatten=flatten, ) am = max(am, d["abs"]) dn = max(dn, d["dnan"]) rm = max(rm, d["rel"]) sm += d["sum"] n += d["n"] return dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn) if isinstance(expected, dict): if verbose >= 6: print(f"[max_diff] dict: {string_type(expected)} ? {string_type(got)}") assert ( begin == 0 and end == -1 ), f"begin={begin}, end={end} not compatible with dictionaries" if isinstance(got, dict): if len(expected) != len(got): return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if set(expected) != set(got): return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) keys = sorted(expected) return max_diff( [expected[k] for k in keys], [got[k] for k in keys], level=level, flatten=flatten, debug_info=debug_info, begin=begin, end=end, _index=_index, verbose=verbose, ) if not isinstance(got, (tuple, list)): return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if len(expected) != len(got): return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) return max_diff( list(expected.values()), got, level=level, flatten=flatten, debug_info=debug_info, begin=begin, end=end, _index=_index, verbose=verbose, ) import torch if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray): if isinstance(expected, torch.Tensor): expected = expected.detach().cpu().numpy() if isinstance(got, torch.Tensor): got = got.detach().cpu().numpy() if verbose >= 6: print(f"[max_diff] tensor: {string_type(expected)} ? 
{string_type(got)}") if _index < begin or (end != -1 and _index >= end): # out of boundary return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) if isinstance(expected, (int, float)): if isinstance(got, np.ndarray) and len(got.shape) == 0: got = float(got) if isinstance(got, (int, float)): if expected == got: return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) return dict( abs=abs(expected - got), rel=abs(expected - got) / (abs(expected) + 1e-5), sum=abs(expected - got), n=1, dnan=0, ) return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) if expected.dtype in (np.complex64, np.complex128): if got.dtype == expected.dtype: got = np.real(got) elif got.dtype not in (np.float32, np.float64): if verbose >= 10: # To understand the value it comes from. if debug_info: print("\n".join(debug_info)) print( f"[max_diff-c] expected.dtype={expected.dtype}, " f"got.dtype={got.dtype}" ) return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) expected = np.real(expected) if expected.shape != got.shape: if verbose >= 10: # To understand the value it comes from. 
if debug_info: print("\n".join(debug_info)) print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}") return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) # nan are replace by 1e10, any discrepancies in that order of magnitude # is likely caused by nans exp_cpu = np.nan_to_num(expected.astype(np.float64), nan=1e10) got_cpu = np.nan_to_num(got.astype(np.float64), nan=1e10) diff = np.abs(got_cpu - exp_cpu) ndiff = np.abs(np.isnan(expected).astype(int) - np.isnan(got).astype(int)) rdiff = diff / (np.abs(exp_cpu) + 1e-3) if diff.size == 0: abs_diff, rel_diff, sum_diff, n_diff, nan_diff = ( (0, 0, 0, 0, 0) if exp_cpu.size == got_cpu.size else (np.inf, np.inf, np.inf, 0, np.inf) ) else: abs_diff, rel_diff, sum_diff, n_diff, nan_diff = ( float(diff.max()), float(rdiff.max()), float(diff.sum()), float(diff.size), float(ndiff.sum()), ) if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10): # To understand the value it comes from. if debug_info: print("\n".join(debug_info)) print( f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, " f"nan_diff={nan_diff}, dtype={expected.dtype}, " f"shape={expected.shape}, level={level}, _index={_index}" ) if abs_diff >= 10: idiff = np.argmax(diff.reshape((-1,))) x = expected.reshape((-1,))[idiff] y = got.reshape((-1,))[idiff] print( f" [max_diff-2] abs diff={abs_diff}, " f"x={x}, y={y}, level={level}, " f"_index={_index}" ) print(y) if rel_diff >= 10: idiff = np.argmax(rdiff.reshape((-1,))) x = expected.reshape((-1,))[idiff] y = got.reshape((-1,))[idiff] print( f" [max_diff-3] rel diff={rel_diff}, " f"x={x}, y={y}, level={level}, " f"_index={_index}" ) return dict(abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff) if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor): if verbose >= 6: print(f"[max_diff] tensor: {string_type(expected)} ? 
{string_type(got)}") if _index < begin or (end != -1 and _index >= end): # out of boundary return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0) if expected.dtype in (torch.complex64, torch.complex128): if got.dtype == expected.dtype: got = torch.view_as_real(got) elif got.dtype not in (torch.float32, torch.float64): if verbose >= 10: # To understand the value it comes from. if debug_info: print("\n".join(debug_info)) print( f"[max_diff-c] expected.dtype={expected.dtype}, " f"got.dtype={got.dtype}" ) return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) expected = torch.view_as_real(expected) if expected.shape != got.shape: if verbose >= 10: # To understand the value it comes from. if debug_info: print("\n".join(debug_info)) print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}") return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) # nan are replace by 1e10, any discrepancies in that order of magnitude # is likely caused by nans exp_cpu = expected.to(torch.float64).cpu().nan_to_num(1e10) got_cpu = got.to(torch.float64).cpu().nan_to_num(1e10) diff = (got_cpu - exp_cpu).abs() ndiff = (expected.isnan().cpu().to(int) - got.isnan().cpu().to(int)).abs() rdiff = diff / (exp_cpu.abs() + 1e-3) if diff.numel() > 0: abs_diff, rel_diff, sum_diff, n_diff, nan_diff = ( float(diff.max()), float(rdiff.max()), float(diff.sum()), float(diff.numel()), float(ndiff.sum()), ) elif got_cpu.numel() == exp_cpu.numel(): abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (0.0, 0.0, 0.0, 0.0, 0.0) else: abs_diff, rel_diff, sum_diff, n_diff, nan_diff = ( np.inf, np.inf, np.inf, np.inf, np.inf, ) if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10): # To understand the value it comes from. 
if debug_info: print("\n".join(debug_info)) print( f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, " f"nan_diff={nan_diff}, dtype={expected.dtype}, " f"shape={expected.shape}, level={level}, _index={_index}" ) if abs_diff >= 10: idiff = torch.argmax(diff.reshape((-1,))) x = expected.reshape((-1,))[idiff] y = got.reshape((-1,))[idiff] print( f" [max_diff-2] abs diff={abs_diff}, " f"x={x}, y={y}, level={level}, " f"_index={_index}" ) print(y) if rel_diff >= 10: idiff = torch.argmax(rdiff.reshape((-1,))) x = expected.reshape((-1,))[idiff] y = got.reshape((-1,))[idiff] print( f" [max_diff-3] rel diff={rel_diff}, " f"x={x}, y={y}, level={level}, " f"_index={_index}" ) return dict(abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff) if "SquashedNormal" in expected.__class__.__name__: if verbose >= 6: print(f"[max_diff] SquashedNormal: {string_type(expected)} ? {string_type(got)}") values = ( expected.mean.detach().to("cpu"), expected.scale.detach().to("cpu"), ) return max_diff( values, got, verbose=verbose, level=level + 1, begin=begin, end=end, _index=_index, flatten=flatten, ) if expected.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"): if got.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"): if verbose >= 6: print(f"[max_diff] DynamicCache: {string_type(expected)} ? 
{string_type(got)}") return max_diff( [expected.key_cache, expected.value_cache], [got.key_cache, got.value_cache], verbose=verbose, ) if isinstance(got, tuple) and len(got) == 2: return max_diff( [expected.key_cache, expected.value_cache], [got[0], got[1]], verbose=verbose, ) raise AssertionError( f"DynamicCache not fully implemented with classes " f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, " f"and expected={string_type(expected)}, got={string_type(got)},\n" f"level={level}" ) if expected.__class__.__name__ in ("transformers.cache_utils.MambaCache", "MambaCache"): if verbose >= 6: print(f"[max_diff] MambaCache: {string_type(expected)} ? {string_type(got)}") if got.__class__.__name__ != expected.__class__.__name__: # This case happens with onnx where the outputs are flattened. return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) atts = [] for k in ["conv_states", "ssm_states"]: if hasattr(expected, k) and not hasattr(got, k): return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf) atts.append(k) return max_diff( [getattr(expected, k) for k in atts], [getattr(got, k) for k in atts], level=level, flatten=flatten, debug_info=debug_info, begin=begin, end=end, _index=_index, verbose=verbose, ) raise AssertionError( f"Not implemented with implemented with expected=" f"{string_type(expected)}, got={string_type(got)},\n" f"level={level}" )
def string_diff(diff: Dict[str, Any]) -> str:
    """Render a discrepancy dictionary produced by :func:`max_diff` as one string.

    :param diff: dictionary with keys ``abs``, ``rel``, ``n`` and
        optionally ``dnan`` (as returned by :func:`max_diff`)
    :return: compact human-readable summary of the discrepancies
    """
    pieces = [f"abs={diff['abs']}", f"rel={diff['rel']}"]
    # ``n`` is only reported when both measures are non-zero
    # (a zero absolute or relative difference means no discrepancy to count).
    if diff["abs"] != 0 and diff["rel"] != 0:
        pieces.append(f"n={diff['n']}")
    # ``dnan`` is appended only when there is an actual nan mismatch
    # (missing key or 0 both mean "no nan discrepancy").
    if diff.get("dnan"):
        pieces.append(f"dnan={diff['dnan']}")
    return ", ".join(pieces)