# Source code for yobx.helpers.helper

import enum
import inspect
from dataclasses import is_dataclass, fields
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np


def _string_tensor(obj, cls: str, with_shape: bool, with_device: bool) -> str:
    """Formats a tensor-like object as ``[G|C]<cls><dtype>[s<shape>|r<rank>]``."""
    from .torch_helper import torch_dtype_to_onnx_dtype

    dtype_code = torch_dtype_to_onnx_dtype(obj.dtype)
    device_mark = ""
    if with_device:
        # G for GPU (get_device() >= 0), C for CPU.
        device_mark = "G" if obj.get_device() >= 0 else "C"
    if with_shape:
        dims = "x".join(str(d) for d in obj.shape)
        return f"{device_mark}{cls}{dtype_code}s{dims}"
    # Without the shape, only the rank is displayed.
    return f"{device_mark}{cls}{dtype_code}r{len(obj.shape)}"


[docs] def string_type( obj: Any, with_shape: bool = False, with_min_max: bool = False, with_device: bool = False, ignore: bool = False, limit: int = 20, ) -> str: """ Displays the types of an object as a string. :param obj: any :param with_shape: displays shapes as well :param with_min_max: displays information about the values :param with_device: display the device :param ignore: if True, just prints the type for unknown types :return: str The function displays something like the following for a tensor. .. code-block:: text T7s2x7[0.5:6:A3.56] ^^^+-^^----+------^ || | | || | +-- information about the content of a tensor or array || | [min,max:A<average>] || | || +-- a shape || |+-- integer following the code defined by onnx.TensorProto, | 7 is onnx.TensorProto.INT64 (see onnx_dtype_name) | +-- A,T,F A is an array from numpy T is a Tensor from pytorch F is a FakeTensor from pytorch The element types for a tensor are displayed as integer to shorten the message. The semantic is defined by :class:`onnx.TensorProto` and can be obtained by :func:`yobx.helpers.onnx_helper.onnx_dtype_name`. .. runpython:: :showcode: from yobx.helpers import string_type print(string_type((1, ["r", 6.6]))) With pytorch: .. 
runpython:: :showcode: import torch from yobx.helpers import string_type inputs = ( torch.rand((3, 4), dtype=torch.float16), [ torch.rand((5, 6), dtype=torch.float16), torch.rand((5, 6, 7), dtype=torch.float16), ] ) # with shapes print(string_type(inputs, with_shape=True)) # with min max print(string_type(inputs, with_shape=True, with_min_max=True)) """ if obj is None: return "None" # tuple if isinstance(obj, tuple): if len(obj) == 1: s = string_type( obj[0], with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) return f"({s},)" if len(obj) < limit: js = ",".join( string_type( o, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) for o in obj ) return f"({js})" tt = string_type( obj[0], with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj): mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj) return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]" return f"#{len(obj)}({tt},...)" # list if isinstance(obj, list): if len(obj) < limit: js = ",".join( string_type( o, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) for o in obj ) return f"#{len(obj)}[{js}]" tt = string_type( obj[0], with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj): mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj) return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]" return f"#{len(obj)}[{tt},...]" # set if isinstance(obj, set): if len(obj) < 10: js = ",".join( string_type( o, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) for o in obj ) return f"{{{js}}}" if with_min_max and all(isinstance(_, 
(int, float, bool)) for _ in obj): mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj) return f"{{...}}#{len(obj)}[{mini},{maxi}:A{avg}]" return f"{{...}}#{len(obj)}" if with_shape else "{...}" # dict if isinstance(obj, dict) and type(obj) is dict: if len(obj) == 0: return "{}" import torch if all(isinstance(k, int) for k in obj) and all( isinstance( v, ( str, torch.export.dynamic_shapes._Dim, torch.export.dynamic_shapes._DerivedDim, torch.export.dynamic_shapes._DimHint, ), ) for v in obj.values() ): # This is dynamic shapes rows = [] for k, v in obj.items(): if isinstance(v, str): rows.append(f"{k}:DYN({v})") else: rows.append(f"{k}:{string_type(v)}") return f"{{{','.join(rows)}}}" kws = dict( with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items()) # type: ignore[arg-type] if all(isinstance(k, int) for k in obj): return f"{{{s}}}" return f"dict({s})" # array if isinstance(obj, np.ndarray): from .onnx_helper import np_dtype_to_tensor_dtype if with_min_max: s = string_type(obj, with_shape=with_shape) if len(obj.shape) == 0: return f"{s}={obj}" if obj.size == 0: return f"{s}[empty]" n_nan = np.isnan(obj.reshape((-1,))).astype(int).sum() if n_nan > 0: nob = obj.ravel() nob = nob[~np.isnan(nob)] if nob.size == 0: return f"{s}[N{n_nan}nans]" return f"{s}[{nob.min()},{nob.max()}:A{nob.astype(float).mean()}N{n_nan}nans]" return f"{s}[{obj.min()},{obj.max()}:A{obj.astype(float).mean()}]" i = np_dtype_to_tensor_dtype(obj.dtype) if not with_shape: return f"A{i}r{len(obj.shape)}" return f"A{i}s{'x'.join(map(str, obj.shape))}" import torch # Dim, SymInt if isinstance(obj, torch.export.dynamic_shapes._DerivedDim): return "DerivedDim" if isinstance(obj, torch.export.dynamic_shapes._Dim): return f"Dim({obj.__name__})" if isinstance(obj, torch.SymInt): return "SymInt" if isinstance(obj, torch.SymFloat): return "SymFloat" if 
isinstance(obj, torch.export.dynamic_shapes._DimHint): cl = ( torch.export.dynamic_shapes._DimHintType if hasattr(torch.export.dynamic_shapes, "_DimHintType") else torch.export.Dim ) if obj in (torch.export.Dim.DYNAMIC, cl.DYNAMIC): return "DYNAMIC" if obj in (torch.export.Dim.AUTO, cl.AUTO): return "AUTO" return str(obj).replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO") if isinstance(obj, bool): if with_min_max: return f"bool={obj}" return "bool" if isinstance(obj, int): if with_min_max: return f"int={obj}" return "int" if isinstance(obj, float): if with_min_max: return f"float={obj}" return "float" if isinstance(obj, str): return "str" if isinstance(obj, slice): return "slice" if is_dataclass(obj): # That includes torch.export.Dim.AUTO, torch.export.Dim.DYNAMIC so they need to be # handled before that. values = {f.name: getattr(obj, f.name, None) for f in fields(obj)} values = {k: v for k, v in values.items() if v is not None} s = string_type( values, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, ignore=ignore, limit=limit, ) return f"{obj.__class__.__name__}{s[4:]}" # Tensors if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor): return _string_tensor(obj, "F", with_shape, with_device) if isinstance(obj, torch.Tensor): from .torch_helper import torch_dtype_to_onnx_dtype if with_min_max: s = string_type(obj, with_shape=with_shape, with_device=with_device) if len(obj.shape) == 0: return f"{s}={obj}" if obj.numel() == 0: return f"{s}[empty]" n_nan = obj.reshape((-1,)).isnan().to(int).sum() if n_nan > 0: nob = obj.reshape((-1,)) nob = nob[~nob.isnan()] if obj.dtype in {torch.complex64, torch.complex128}: return f"{s}[{nob.abs().min()},{nob.abs().max():A{nob.mean()}N{n_nan}nans}]" return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}N{n_nan}nans]" if obj.dtype in {torch.complex64, torch.complex128}: return f"{s}[{obj.abs().min()},{obj.abs().max()}:A{obj.abs().mean()}]" return 
f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}]" i = torch_dtype_to_onnx_dtype(obj.dtype) prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else "" if not with_shape: return f"{prefix}T{i}r{len(obj.shape)}" return f"{prefix}T{i}s{'x'.join(map(str, obj.shape))}" if obj.__class__.__name__ == "OrtValue": if not obj.has_value(): return "OV(<novalue>)" if not obj.is_tensor(): return "OV(NOTENSOR)" if with_min_max: try: t = obj.numpy() except Exception: return "OV(NO-NUMPY:FIXIT)" dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else "" return f"{dev}OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})" dt = obj.element_type() shape = obj.shape() dev = ("G" if obj.device_name() == "Cuda" else "C") if with_device else "" if with_shape: return f"{dev}OV{dt}s{'x'.join(map(str, shape))}" return f"{dev}OV{dt}r{len(shape)}" if obj.__class__.__name__ == "SymbolicTensor": return _string_tensor(obj, "ST", with_shape, with_device) if ( obj.__class__.__name__ in {"DynamicCache"} and hasattr(obj, "layers") and any(lay.__class__.__name__ != "DynamicLayer" for lay in obj.layers) ): slay = [] for lay in obj.layers: skeys = string_type( lay.keys, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) svalues = string_type( lay.keys, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) slay.append(f"{lay.__class__.__name__}({skeys}, {svalues})") return f"{obj.__class__.__name__}({', '.join(slay)})" if obj.__class__.__name__ in { "DynamicCache", "SlidingWindowCache", "StaticCache", "HybridCache", }: from .cache_helper import CacheKeyValue ca = CacheKeyValue(obj) kc = string_type( ca.key_cache, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) vc = string_type( ca.value_cache, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"{obj.__class__.__name__}(key_cache={kc}, 
value_cache={vc})" if obj.__class__.__name__ == "StaticLayer": kc = string_type( list(obj.keys), with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) vc = string_type( list(obj.values), with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"{obj.__class__.__name__}(keys={kc}, values={vc})" if obj.__class__.__name__ == "EncoderDecoderCache": att = string_type( obj.self_attention_cache, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) cross = string_type( obj.cross_attention_cache, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return ( f"{obj.__class__.__name__}(self_attention_cache={att}, " f"cross_attention_cache={cross})" ) if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES: from .cache_helper import flatten_unflatten_for_dynamic_shapes args = flatten_unflatten_for_dynamic_shapes(obj) att = string_type( args, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"{obj.__class__.__name__}[serialized]({att})" if type(obj).__name__ == "Node" and hasattr(obj, "meta"): # torch.fx.node.Node return f"%{obj.target}" if type(obj).__name__ == "ValueInfoProto": return f"OT{obj.type.tensor_type.elem_type}" if obj.__class__.__name__ == "BatchFeature": s = string_type( obj.data, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"BatchFeature(data={s})" if obj.__class__.__name__ == "BatchEncoding": s = string_type( obj.data, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"BatchEncoding(data={s})" if obj.__class__.__name__ == "VirtualTensor": def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F821 if isinstance(value, str): return value if hasattr(value, "node") and isinstance(value.node, str): return f"{value.node}" from 
torch.fx.experimental.sym_node import SymNode if hasattr(value, "node") and isinstance(value.node, SymNode): # '_expr' is safer than expr return str(value.node._expr).replace(" ", "") try: val_int = int(value) return val_int except ( TypeError, ValueError, AttributeError, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode, ): pass raise AssertionError(f"Unable to convert {value!r} into string") return ( f"{obj.__class__.__name__}(name={obj.name!r}, " f"dtype={obj.dtype}, shape={tuple(_torch_sym_int_to_str(_) for _ in obj.shape)})" ) if obj.__class__.__name__ == "KeyValuesWrapper": import transformers assert isinstance( obj, transformers.cache_utils.KeyValuesWrapper ), f"Unexpected type {type(obj)}" s = string_type( list(obj), with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"{obj.__class__.__name__}[{obj.cache_type}]{s}" if obj.__class__.__name__ == "DynamicLayer": import transformers assert isinstance( obj, transformers.cache_utils.DynamicLayer ), f"Unexpected type {type(obj)}" s1 = string_type( obj.keys, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) s2 = string_type( obj.values, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"{obj.__class__.__name__}(keys={s1}, values={s2})" if isinstance(obj, torch.nn.Module): return f"{obj.__class__.__name__}(...)" if isinstance(obj, (torch.device, torch.dtype, torch.memory_format, torch.layout)): return f"{obj.__class__.__name__}({obj})" if isinstance( # TreeSpec, MappingKey, SequenceKey obj, ( torch.utils._pytree.TreeSpec, torch.utils._pytree.MappingKey, torch.utils._pytree.SequenceKey, ), ): return repr(obj).replace(" ", "").replace("\n", " ") if isinstance(obj, torch.fx.proxy.Proxy): return repr(obj) if ignore: return f"{obj.__class__.__name__}(...)" if obj.__class__.__name__.endswith("Config"): import transformers.configuration_utils as tcu if isinstance(obj, 
tcu.PretrainedConfig): s = str(obj.to_diff_dict()).replace("\n", "").replace(" ", "") return f"{obj.__class__.__name__}(**{s})" if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}: return f"{obj.__class__.__name__}(...)" if obj.__class__.__name__ == "Results": import ultralytics assert isinstance(obj, ultralytics.engine.results.Results), f"Unexpected type={type(obj)}" return f"ultralytics.{obj.__class__.__name__}(...)" if obj.__class__.__name__ == "FakeTensorMode": return f"{obj}" if obj.__class__.__name__ == "FakeTensorContext": return "FakeTensorContext(...)" if obj.__class__.__name__ == "Chat": import transformers.utils.chat_template_utils as ctu assert isinstance(obj, ctu.Chat), f"unexpected type {type(obj)}" msg = string_type( obj.messages, with_shape=with_shape, with_min_max=with_min_max, with_device=with_device, limit=limit, ) return f"Chat({msg})" raise TypeError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
def string_signature(sig: Any) -> str:
    """
    Displays the signature of a function.

    :param sig: an :class:`inspect.Signature`
    :return: a multi-line string, one line per parameter, showing its name,
        annotation, default value and parameter kind
    """

    def _k(p, kind):
        # Returns the attribute name on ``p`` whose value equals ``kind``
        # (e.g. "VAR_POSITIONAL"); falls back to ``repr(kind)``.
        for name in dir(p):
            if getattr(p, name) == kind:
                return name
        return repr(kind)

    text = [" __call__ ("]
    for p in sig.parameters:
        pp = sig.parameters[p]
        # BUG FIX: was ``kind = repr(pp.kind)``, a string that can never equal
        # ``pp.VAR_POSITIONAL`` below (so ``*args`` was never starred) and that
        # made ``_k`` display an unreadable quoted repr instead of the kind name.
        kind = pp.kind
        t = f"{p}: {pp.annotation}" if pp.annotation is not inspect._empty else p
        if pp.default is not inspect._empty:
            t = f"{t} = {pp.default!r}"
        if kind == pp.VAR_POSITIONAL:
            t = f"*{t}"
        le = (30 - len(t)) * " "
        text.append(f" {t}{le}|{_k(pp,kind)}")
    text.append(
        f") -> {sig.return_annotation}" if sig.return_annotation is not inspect._empty else ")"
    )
    return "\n".join(text)
def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
    """
    Displays the signature of a function as a call string, keeping only the
    arguments whose given value differs from the declared default.

    :param f: a callable, or an instance whose ``__init__`` signature and
        ``__dict__`` are used when ``kwargs`` is None
    :param kwargs: the actual values; if None, they are taken from ``f.__dict__``
    :return: a string such as ``Name(arg=value, ...)``
    """
    # NOTE(review): ``hasattr(f, "__init__")`` is True for every object,
    # including plain functions, so a plain function must be called with an
    # explicit ``kwargs`` to take the else branch — confirm this is intended.
    if hasattr(f, "__init__") and kwargs is None:
        fct = f.__init__
        kwargs = f.__dict__
        name = f.__class__.__name__
    else:
        fct = f
        name = f.__name__
    if kwargs is None:
        kwargs = {}
    rows = []
    sig = inspect.signature(fct)
    for p in sig.parameters:
        pp = sig.parameters[p]
        d = pp.default
        if d is inspect._empty:
            # No default: display the parameter only when a value was given.
            if p in kwargs:
                v = kwargs[p]
                rows.append(
                    f"{p}={v!r}" if not isinstance(v, enum.IntEnum) else f"{p}={v.name}"
                )
            continue
        v = kwargs.get(p, d)
        # Display the parameter only when it differs from its default.
        if d != v:
            rows.append(
                f"{p}={v!r}" if not isinstance(v, enum.IntEnum) else f"{p}={v.name}"
            )
            continue
    atts = ", ".join(rows)
    return f"{name}({atts})"
def make_hash(obj: Any) -> str:
    """
    Returns a simple hash of ``id(obj)`` as three uppercase letters.
    """
    code = id(obj) % (26**3)
    digits = (code // 26**2, (code // 26) % 26, code % 26)
    return "".join(chr(65 + d) for d in digits)
def flatten_object(x: Any, drop_keys: bool = False) -> Any:
    """
    Flattens the object. It accepts some common classes used in deep learning.

    :param x: any object
    :param drop_keys: drop the keys if a dictionary is flattened.
        Keeps the order defined by the dictionary if False, sort them if True.
    :return: flattened object
    """
    if x is None:
        return x
    if isinstance(x, (list, tuple)):
        flat = []
        for item in x:
            # Leaves are kept as-is: None, tensors/arrays (anything with a
            # shape), and scalar-like builtins.
            is_leaf = (
                item is None
                or hasattr(item, "shape")
                or isinstance(item, (int, float, str))
            )
            if is_leaf:
                flat.append(item)
            else:
                flat.extend(flatten_object(item, drop_keys=drop_keys))
        return tuple(flat) if isinstance(x, tuple) else flat
    if isinstance(x, dict):
        # We flatten the keys.
        source = x.values() if drop_keys else x.items()
        return flatten_object(list(source), drop_keys=drop_keys)
    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
        from .cache_helper import CacheKeyValue

        return CacheKeyValue(x).aslist()
    if x.__class__.__name__ == "EncoderDecoderCache":
        return (
            *flatten_object(x.self_attention_cache),
            *flatten_object(x.cross_attention_cache),
        )
    if hasattr(x, "to_tuple"):
        return flatten_object(x.to_tuple(), drop_keys=drop_keys)
    if hasattr(x, "shape"):
        # A tensor. Nothing to do.
        return x
    raise TypeError(
        f"Unexpected type {type(x)} for x, drop_keys={drop_keys}, "
        f"content is {string_type(x, with_shape=True)}"
    )
def _make_debug_info(msg, level, debug_info) -> Optional[List[str]]: return [*(debug_info if debug_info else []), f"{' ' * level}{msg}"]
def max_diff(
    expected: Any,
    got: Any,
    level: int = 0,
    flatten: bool = False,
    debug_info: Optional[List[str]] = None,
    begin: int = 0,
    end: int = -1,
    _index: int = 0,
    allow_unique_tensor_with_list_of_one_element: bool = True,
    hist: Optional[Union[bool, List[float]]] = None,
    skip_none: bool = False,
) -> Dict[str, Union[float, int, Tuple[Any, ...]]]:
    """
    Returns the maximum discrepancy.

    :param expected: expected values
    :param got: values
    :param level: for embedded outputs, used for debug purposes
    :param flatten: flatten outputs
    :param debug_info: debug information
    :param begin: first output to considered
    :param end: last output to considered (-1 for the last one)
    :param _index: used with begin and end
    :param allow_unique_tensor_with_list_of_one_element:
        allow a comparison between a single tensor and a list of one tensor
    :param hist: compute an histogram of the discrepancies
    :param skip_none: skips none value
    :return: dictionary with many values

        * abs: max absolute error
        * rel: max relative error
        * sum: sum of the errors
        * n: number of outputs values, if there is one output,
          this number will be the number of elements of this output
        * dnan: difference in the number of nan
        * dev: tensor on the same device, if applicable

    You may use :func:`string_diff` to display the discrepancies in one string.
    """
    if expected is None and got is None:
        return dict(abs=0, rel=0, sum=0, n=0, dnan=0)
    # Keyword bundles reused by the recursive calls below:
    # _dkws propagates ``flatten``, _dkwsf forces flatten=False.
    _dkws_ = dict(
        level=level + 1,
        begin=begin,
        end=end,
        _index=_index,
        hist=hist,
        skip_none=skip_none,
    )
    _dkws = {**_dkws_, "flatten": flatten}
    _dkwsf = {**_dkws_, "flatten": False}
    _debug = lambda msg: _make_debug_info(msg, level, debug_info)  # noqa: E731
    if allow_unique_tensor_with_list_of_one_element:
        # A single tensor may be compared against a list of one tensor;
        # the flag is cleared on recursion to do this unwrapping only once.
        if hasattr(expected, "shape") and isinstance(got, (list, tuple)) and len(got) == 1:
            return max_diff(
                expected,
                got[0],
                level=level,
                flatten=False,
                debug_info=debug_info,
                allow_unique_tensor_with_list_of_one_element=False,
                hist=hist,
                skip_none=skip_none,
            )
        return max_diff(
            expected,
            got,
            level=level,
            flatten=flatten,
            debug_info=debug_info,
            begin=begin,
            end=end,
            _index=_index,
            allow_unique_tensor_with_list_of_one_element=False,
            hist=hist,
            skip_none=skip_none,
        )
    if expected.__class__.__name__ == "CausalLMOutputWithPast":
        # transformers output: compare logits followed by the flattened cache.
        if got.__class__.__name__ == "CausalLMOutputWithPast":
            return max_diff(
                [expected.logits, *flatten_object(expected.past_key_values)],
                [got.logits, *flatten_object(got.past_key_values)],
                debug_info=_debug(expected.__class__.__name__),
                **_dkws,
            )
        return max_diff(
            [expected.logits, *flatten_object(expected.past_key_values)],
            got,
            debug_info=_debug(expected.__class__.__name__),
            **_dkws,
        )
    if hasattr(expected, "to_tuple"):
        return max_diff(expected.to_tuple(), got, debug_info=_debug("to_tuple1"), **_dkws)
    if hasattr(got, "to_tuple"):
        return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)
    if isinstance(expected, (tuple, list)):
        if len(expected) == 1 and not isinstance(got, type(expected)):
            return max_diff(expected[0], got, debug_info=_debug("lt2"), **_dkws)
        if not isinstance(got, (tuple, list)):
            # Structural mismatch: reported as infinite discrepancy.
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        if len(got) != len(expected):
            if flatten:
                # Let's flatten.
                flat_a = flatten_object(expected, drop_keys=True)
                flat_b = flatten_object(got, drop_keys=True)
                return max_diff(
                    flat_a,
                    flat_b,
                    debug_info=[
                        *(debug_info if debug_info else []),
                        (f"{' ' * level}flatten[{string_type(expected)},{string_type(got)}]"),
                    ],
                    **_dkwsf,
                )
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        # Element-wise recursion: max over abs/rel/dnan, sum over sum/n;
        # "rep" histograms are added together, "dev" flags are accumulated.
        am, rm, sm, n, dn, drep, dd = 0, 0, 0.0, 0.0, 0, None, None
        for ip, (e, g) in enumerate(zip(expected, got)):
            d = max_diff(
                e,
                g,
                level=level + 1,
                debug_info=[
                    *(debug_info if debug_info else []),
                    f"{' ' * level}[{ip}] so far abs {am} - rel {rm}",
                ],
                begin=begin,
                end=end,
                _index=_index + ip,
                flatten=flatten,
                hist=hist,
                skip_none=skip_none,
            )
            am = max(am, d["abs"])
            dn = max(dn, d["dnan"])
            rm = max(rm, d["rel"])
            sm += d["sum"]  # type: ignore
            n += d["n"]  # type: ignore
            if "rep" in d:
                if drep is None:
                    drep = d["rep"].copy()
                else:
                    for k, v in d["rep"].items():
                        drep[k] += v
            if "dev" in d and d["dev"] is not None:
                if dd is None:
                    dd = d["dev"]
                else:
                    dd += d["dev"]  # type: ignore[operator]
        res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
        if dd is not None:
            res["dev"] = dd
        if drep:
            res["rep"] = drep
        return res  # type: ignore
    if isinstance(expected, dict):
        assert begin == 0 and end == -1, (
            f"begin={begin}, end={end} not compatible with dictionaries, "
            f"keys={sorted(expected)}"
        )
        if isinstance(got, dict):
            if len(expected) != len(got):
                return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
            if set(expected) != set(got):
                return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
            # Same keys: compare values in sorted-key order.
            keys = sorted(expected)
            return max_diff(
                [expected[k] for k in keys],
                [got[k] for k in keys],
                debug_info=_debug("dict1"),
                **_dkws,
            )
        if not isinstance(got, (tuple, list)):
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        if len(expected) != len(got):
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        # dict vs sequence: values are compared in the dictionary's own order.
        return max_diff(list(expected.values()), got, debug_info=_debug("dict2"), **_dkws)
    if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
        dev = None
        import torch

        # Torch tensors are converted to numpy; ``dev`` remembers whether the
        # last converted tensor lived on CPU (0) or another device (1).
        if isinstance(expected, torch.Tensor):
            from .torch_helper import to_numpy

            dev = 0 if expected.device.type == "cpu" else 1
            expected = to_numpy(expected)
        if isinstance(got, torch.Tensor):
            from .torch_helper import to_numpy

            dev = 0 if got.device.type == "cpu" else 1
            got = to_numpy(got)
        if isinstance(got, (list, tuple)):
            got = np.array(got)
        if isinstance(expected, (list, tuple)):
            expected = np.array(expected)
        if _index < begin or (end != -1 and _index >= end):
            # out of boundary
            res = dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
            if dev is not None:
                res["dev"] = dev  # type: ignore[operator]
            return res  # type: ignore[return-value]
        if isinstance(expected, (int, float)):
            if isinstance(got, np.ndarray) and len(got.shape) == 0:
                got = float(got)
            if isinstance(got, (int, float)):
                if expected == got:
                    return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
                res = dict(
                    abs=abs(expected - got),
                    rel=abs(expected - got) / (abs(expected) + 1e-5),
                    sum=abs(expected - got),
                    n=1,
                    dnan=0,
                )
                if dev is not None:
                    res["dev"] = dev
                return res  # type: ignore[return-value]
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        if expected.dtype in (np.complex64, np.complex128):
            # Complex arrays are compared on their real part only.
            if got.dtype == expected.dtype:
                got = np.real(got)
            elif got.dtype not in (np.float32, np.float64):
                return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
            expected = np.real(expected)
        if expected.shape != got.shape:
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        # nan are replace by 1e10, any discrepancies in that order of magnitude
        # is likely caused by nans
        exp_cpu = np.nan_to_num(expected.astype(np.float64), nan=1e10)
        got_cpu = np.nan_to_num(got.astype(np.float64), nan=1e10)
        diff = np.abs(got_cpu - exp_cpu)
        ndiff = np.abs(np.isnan(expected).astype(int) - np.isnan(got).astype(int))
        rdiff = diff / (np.abs(exp_cpu) + 1e-3)
        if diff.size == 0:
            abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
                (0, 0, 0, 0, 0)
                if exp_cpu.size == got_cpu.size
                else (np.inf, np.inf, np.inf, 0, np.inf)
            )
            argm = None
        else:
            abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
                float(diff.max()),
                float(rdiff.max()),
                float(diff.sum()),
                float(diff.size),
                float(ndiff.sum()),
            )
            # argm: multi-dimensional index of the largest absolute difference.
            argm = tuple(map(int, np.unravel_index(diff.argmax(), diff.shape)))
        res: Dict[str, float] = dict(  # type: ignore
            abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
        )
        if dev is not None:
            res["dev"] = dev
        if hist:
            if isinstance(hist, bool):
                # Default histogram thresholds.
                hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
        return res  # type: ignore

    import torch

    if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
        dev = 0 if expected.device == got.device else 1
        if _index < begin or (end != -1 and _index >= end):
            # out of boundary
            return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0, dev=dev)
        if expected.dtype in (torch.complex64, torch.complex128):
            # Complex tensors are viewed as real pairs before comparison.
            if got.dtype == expected.dtype:
                got = torch.view_as_real(got)
            elif got.dtype not in (torch.float32, torch.float64):
                return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
            expected = torch.view_as_real(expected)
        if expected.shape != got.shape:
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        # nan are replace by 1e10, any discrepancies in that order of magnitude
        # is likely caused by nans
        exp_cpu = expected.to(torch.float64).nan_to_num(1e10)
        got_cpu = got.to(torch.float64).nan_to_num(1e10)
        if got_cpu.device != exp_cpu.device:
            # Move both tensors to a common device before subtracting.
            if torch.device("cuda:0") in {got_cpu.device, exp_cpu.device}:
                got_cpu = got_cpu.to("cuda:0")
                exp_cpu = exp_cpu.to("cuda:0")
                expected = expected.to("cuda:0")
                got = got.to("cuda:0")
            else:
                got_cpu = got_cpu.detach().to("cpu")
                exp_cpu = exp_cpu.detach().to("cpu")
                expected = expected.to("cpu")
                got = got.to("cpu")
        diff = (got_cpu - exp_cpu).abs()
        ndiff = (expected.isnan().to(int) - got.isnan().to(int)).abs()
        rdiff = diff / (exp_cpu.abs() + 1e-3)
        if diff.numel() > 0:
            abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
                float(diff.max().detach()),
                float(rdiff.max().detach()),
                float(diff.sum().detach()),
                float(diff.numel()),
                float(ndiff.sum().detach()),
            )
            argm = tuple(map(int, torch.unravel_index(diff.argmax(), diff.shape)))
        elif got_cpu.numel() == exp_cpu.numel():
            abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (0.0, 0.0, 0.0, 0.0, 0.0)
            argm = None
        else:
            abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
                np.inf,
                np.inf,
                np.inf,
                np.inf,
                np.inf,
            )
            argm = None
        res: Dict[str, float] = dict(  # type: ignore
            abs=abs_diff,
            rel=rel_diff,
            sum=sum_diff,
            n=n_diff,
            dnan=nan_diff,
            argm=argm,
            dev=dev,
        )
        if hist:
            if isinstance(hist, bool):
                hist = torch.tensor(
                    [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
                )
            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
        return res  # type: ignore
    if isinstance(expected, int) and isinstance(got, torch.Tensor):
        # a size
        if got.shape != tuple():
            return dict(  # type: ignore
                abs=np.inf,
                rel=np.inf,
                sum=np.inf,
                n=np.inf,
                dnan=np.inf,
                argm=np.inf,
            )
        return dict(  # type: ignore
            abs=abs(expected - got.item()),
            rel=abs((expected - got.item()) / max(1, expected)),
            sum=abs(expected - got.item()),
            n=1,
            dnan=0,
        )
    if "SquashedNormal" in expected.__class__.__name__:
        # Distributions are compared through their (mean, scale) parameters.
        values = (expected.mean, expected.scale)
        return max_diff(values, got, debug_info=_debug("SquashedNormal"), **_dkws)
    if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        # Both sides must be pytree-serializable; compare the flattened leaves.
        if got.__class__ not in torch.utils._pytree.SUPPORTED_NODES:
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        expected_args, _spec = torch.utils._pytree.tree_flatten(expected)
        got_args, _spec = torch.utils._pytree.tree_flatten(got)
        return max_diff(
            expected_args, got_args, debug_info=_debug(expected.__class__.__name__), **_dkws
        )
    # backup function in case pytorch does not know how to serialize.
    if expected.__class__.__name__ == "DynamicCache":
        if got.__class__.__name__ == "DynamicCache":
            from .cache_helper import CacheKeyValue

            expected = CacheKeyValue(expected)
            got = CacheKeyValue(got)
            return max_diff(
                [expected.key_cache, expected.value_cache],
                [got.key_cache, got.value_cache],
                hist=hist,
            )
        if isinstance(got, tuple) and len(got) == 2:
            from .cache_helper import CacheKeyValue

            if not isinstance(expected, CacheKeyValue):
                expected = CacheKeyValue(expected)
            return max_diff(
                [expected.key_cache, expected.value_cache],
                [got[0], got[1]],
                debug_info=_debug(expected.__class__.__name__),
                **_dkws,
            )
        raise AssertionError(
            f"DynamicCache not fully implemented with classes "
            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
            f"and expected={string_type(expected)}, got={string_type(got)},\n"
            f"level={level}"
        )
    if expected.__class__.__name__ == "StaticCache":
        if got.__class__.__name__ == "StaticCache":
            from .cache_helper import CacheKeyValue

            cae = CacheKeyValue(expected)
            cag = CacheKeyValue(got)
            return max_diff(
                [cae.key_cache, cae.value_cache],
                [cag.key_cache, cag.value_cache],
                hist=hist,
            )
        if isinstance(got, tuple) and len(got) == 2:
            from .cache_helper import CacheKeyValue

            cae = CacheKeyValue(expected)
            return max_diff(
                [cae.key_cache, cae.value_cache],
                [got[0], got[1]],
                debug_info=_debug(expected.__class__.__name__),
                **_dkws,
            )
        raise AssertionError(
            f"StaticCache not fully implemented with classes "
            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
            f"and expected={string_type(expected)}, got={string_type(got)},\n"
            f"level={level}"
        )
    if expected.__class__.__name__ == "CacheKeyValue":
        from .cache_helper import CacheKeyValue

        if got.__class__.__name__ == "CacheKeyValue":
            return max_diff(
                [expected.key_cache, expected.value_cache],
                [got.key_cache, got.value_cache],
                hist=hist,
            )
        if isinstance(got, tuple) and len(got) == 2:
            return max_diff(
                [expected.key_cache, expected.value_cache],
                [got[0], got[1]],
                debug_info=_debug(expected.__class__.__name__),
                **_dkws,
            )
        raise AssertionError(
            f"CacheKeyValue not fully implemented with classes "
            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
            f"and expected={string_type(expected)}, got={string_type(got)},\n"
            f"level={level}"
        )
    if expected.__class__.__name__ == "EncoderDecoderCache":
        if got.__class__.__name__ == "EncoderDecoderCache":
            return max_diff(
                [expected.self_attention_cache, expected.cross_attention_cache],
                [got.self_attention_cache, got.cross_attention_cache],
                hist=hist,
            )
        if isinstance(got, tuple) and len(got) == 2:
            return max_diff(
                [expected.self_attention_cache, expected.cross_attention_cache],
                [got[0], got[1]],
                debug_info=_debug(expected.__class__.__name__),
                **_dkws,
            )
        raise AssertionError(
            f"EncoderDecoderCache not fully implemented with classes "
            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
            f"and expected={string_type(expected)}, got={string_type(got)},\n"
            f"level={level}"
        )
    if expected.__class__.__name__ == "KeyValuesWrapper":
        if got.__class__.__name__ != expected.__class__.__name__:
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        if got.cache_type != expected.cache_type:
            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        return max_diff(
            list(expected),
            list(got),
            debug_info=_debug(expected.__class__.__name__),
            **_dkws,
        )
    if skip_none and (expected is None or got is None):
        return {"abs": 0, "rel": 0, "dnan": 0, "n": 0, "sum": 0}
    raise AssertionError(
        f"Not implemented with implemented with expected="
        f"{string_type(expected)} ({type(expected)}), got={string_type(got)},\n"
        f"level={level}"
    )