import ast
import enum
import inspect
from typing import Any, Callable, Dict, List, Optional, Set
import numpy as np
def size_type(dtype: Any) -> int:
"""Returns the element size for an element type."""
if isinstance(dtype, int):
from onnx import TensorProto
        # dtype is an ONNX TensorProto element type (an integer constant)
if dtype in {
TensorProto.DOUBLE,
TensorProto.INT64,
TensorProto.UINT64,
TensorProto.COMPLEX64,
}:
return 8
if dtype in {TensorProto.FLOAT, TensorProto.INT32, TensorProto.UINT32}:
return 4
if dtype in {
TensorProto.FLOAT16,
TensorProto.BFLOAT16,
TensorProto.INT16,
TensorProto.UINT16,
}:
return 2
if dtype in {
TensorProto.INT8,
TensorProto.UINT8,
TensorProto.BOOL,
TensorProto.FLOAT8E4M3FN,
TensorProto.FLOAT8E4M3FNUZ,
TensorProto.FLOAT8E5M2,
TensorProto.FLOAT8E5M2FNUZ,
}:
return 1
        if dtype == TensorProto.COMPLEX128:
            return 16
        from .onnx_helper import onnx_dtype_name
raise AssertionError(
f"Unable to return the element size for type {onnx_dtype_name(dtype)}"
)
    if dtype == np.float64 or dtype == np.int64:
        return 8
    if dtype == np.float32 or dtype == np.int32:
        return 4
    if dtype == np.float16 or dtype == np.int16:
        return 2
    if dtype == np.int8:
        return 1
if hasattr(np, "uint64"):
# it fails on mac
if dtype == np.uint64:
return 8
if dtype == np.uint32:
return 4
if dtype == np.uint16:
return 2
if dtype == np.uint8:
return 1
import torch
if dtype in {torch.float64, torch.int64}:
return 8
if dtype in {torch.float32, torch.int32}:
return 4
if dtype in {torch.float16, torch.int16, torch.bfloat16}:
return 2
if dtype in {torch.int8, torch.uint8, torch.bool}:
return 1
if hasattr(torch, "uint64"):
# it fails on mac
        if dtype == torch.uint64:
            return 8
        if dtype == torch.uint32:
            return 4
        if dtype == torch.uint16:
            return 2
import ml_dtypes
if dtype == ml_dtypes.bfloat16:
return 2
raise AssertionError(f"Unexpected dtype={dtype}")
def string_type(
obj: Any,
with_shape: bool = False,
with_min_max: bool = False,
with_device: bool = False,
ignore: bool = False,
limit: int = 10,
) -> str:
"""
Displays the types of an object as a string.
    :param obj: any
    :param with_shape: displays shapes as well
    :param with_min_max: displays information about the values
    :param with_device: displays the device
    :param ignore: if True, just prints the type for unknown types
    :param limit: maximum number of elements to display for tuples and lists
    :return: str

.. runpython::
:showcode:
from onnx_diagnostic.helpers import string_type
print(string_type((1, ["r", 6.6])))
With pytorch:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.helpers import string_type
inputs = (
torch.rand((3, 4), dtype=torch.float16),
[
torch.rand((5, 6), dtype=torch.float16),
torch.rand((5, 6, 7), dtype=torch.float16),
]
)
# with shapes
print(string_type(inputs, with_shape=True))
# with min max
print(string_type(inputs, with_shape=True, with_min_max=True))
"""
if obj is None:
return "None"
# tuple
if isinstance(obj, tuple):
if len(obj) == 1:
s = string_type(
obj[0],
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
ignore=ignore,
limit=limit,
)
return f"({s},)"
if len(obj) < limit:
js = ",".join(
string_type(
o,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
ignore=ignore,
limit=limit,
)
for o in obj
)
return f"({js})"
tt = string_type(
obj[0],
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
ignore=ignore,
limit=limit,
)
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]"
return f"#{len(obj)}({tt},...)"
# list
if isinstance(obj, list):
if len(obj) < limit:
js = ",".join(
string_type(
o,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
ignore=ignore,
limit=limit,
)
for o in obj
)
return f"#{len(obj)}[{js}]"
tt = string_type(
obj[0],
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
ignore=ignore,
limit=limit,
)
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]"
return f"#{len(obj)}[{tt},...]"
# set
if isinstance(obj, set):
        if len(obj) < limit:
js = ",".join(
string_type(
o,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
ignore=ignore,
limit=limit,
)
for o in obj
)
return f"{{{js}}}"
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
return f"{{...}}#{len(obj)}[{mini},{maxi}:A{avg}]"
return f"{{...}}#{len(obj)}" if with_shape else "{...}"
# dict
if isinstance(obj, dict):
if len(obj) == 0:
return "{}"
kws = dict(
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
ignore=ignore,
limit=limit,
)
s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items())
return f"dict({s})"
    # numpy array
if isinstance(obj, np.ndarray):
from .onnx_helper import np_dtype_to_tensor_dtype
if with_min_max:
s = string_type(obj, with_shape=with_shape)
if len(obj.shape) == 0:
return f"{s}={obj}"
if obj.size == 0:
return f"{s}[empty]"
n_nan = np.isnan(obj.reshape((-1,))).astype(int).sum()
if n_nan > 0:
nob = obj.ravel()
nob = nob[~np.isnan(nob)]
if nob.size == 0:
return f"{s}[N{n_nan}nans]"
return f"{s}[{nob.min()},{nob.max()}:A{nob.astype(float).mean()}N{n_nan}nans]"
return f"{s}[{obj.min()},{obj.max()}:A{obj.astype(float).mean()}]"
i = np_dtype_to_tensor_dtype(obj.dtype)
if not with_shape:
return f"A{i}r{len(obj.shape)}"
return f"A{i}s{'x'.join(map(str, obj.shape))}"
import torch
# Dim, SymInt
if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
return "DerivedDim"
if isinstance(obj, torch.export.dynamic_shapes._Dim):
return "Dim"
if isinstance(obj, torch.SymInt):
return "SymInt"
if isinstance(obj, torch.SymFloat):
return "SymFloat"
# Tensors
if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
from .onnx_helper import torch_dtype_to_onnx_dtype
i = torch_dtype_to_onnx_dtype(obj.dtype)
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
if not with_shape:
return f"{prefix}F{i}r{len(obj.shape)}"
return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
if isinstance(obj, torch.Tensor):
from .onnx_helper import torch_dtype_to_onnx_dtype
if with_min_max:
s = string_type(obj, with_shape=with_shape, with_device=with_device)
if len(obj.shape) == 0:
return f"{s}={obj}"
if obj.numel() == 0:
return f"{s}[empty]"
n_nan = obj.reshape((-1,)).isnan().to(int).sum()
if n_nan > 0:
nob = obj.reshape((-1,))
nob = nob[~nob.isnan()]
if obj.dtype in {torch.complex64, torch.complex128}:
                    return (
                        f"{s}[{nob.abs().min()},{nob.abs().max()}"
                        f":A{nob.mean()}N{n_nan}nans]"
                    )
return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}N{n_nan}nans]"
if obj.dtype in {torch.complex64, torch.complex128}:
return f"{s}[{obj.abs().min()},{obj.abs().max()}:A{obj.abs().mean()}]"
return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}]"
i = torch_dtype_to_onnx_dtype(obj.dtype)
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
if not with_shape:
return f"{prefix}T{i}r{len(obj.shape)}"
return f"{prefix}T{i}s{'x'.join(map(str, obj.shape))}"
if obj.__class__.__name__ == "OrtValue":
if not obj.has_value():
return "OV(<novalue>)"
if not obj.is_tensor():
return "OV(NOTENSOR)"
if with_min_max:
try:
t = obj.numpy()
except Exception:
                # unable to convert to numpy (bfloat16, ...)
return "OV(NO-NUMPY:FIXIT)"
return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
dt = obj.element_type()
shape = obj.shape()
if with_shape:
return f"OV{dt}s{'x'.join(map(str, shape))}"
return f"OV{dt}r{len(shape)}"
if isinstance(obj, bool):
if with_min_max:
return f"bool={obj}"
return "bool"
if isinstance(obj, int):
if with_min_max:
return f"int={obj}"
return "int"
if isinstance(obj, float):
if with_min_max:
return f"float={obj}"
return "float"
if isinstance(obj, str):
return "str"
if isinstance(obj, slice):
return "slice"
    # other classes
if type(obj).__name__ == "MambaCache":
c = string_type(
obj.conv_states,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
d = string_type(
obj.ssm_states,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
return f"MambaCache(conv_states={c}, ssm_states={d})"
if type(obj).__name__ == "Node" and hasattr(obj, "meta"):
# torch.fx.node.Node
return f"%{obj.target}"
if type(obj).__name__ == "ValueInfoProto":
return f"OT{obj.type.tensor_type.elem_type}"
if obj.__class__.__name__ == "DynamicCache":
kc = string_type(
obj.key_cache,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
vc = string_type(
obj.value_cache,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
if obj.__class__.__name__ == "EncoderDecoderCache":
att = string_type(
obj.self_attention_cache,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
cross = string_type(
obj.cross_attention_cache,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
return (
f"{obj.__class__.__name__}(self_attention_cache={att}, "
f"cross_attention_cache={cross})"
)
if obj.__class__.__name__ == "BatchFeature":
s = string_type(
obj.data,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
return f"BatchFeature(data={s})"
if obj.__class__.__name__ == "BatchEncoding":
s = string_type(
obj.data,
with_shape=with_shape,
with_min_max=with_min_max,
with_device=with_device,
limit=limit,
)
return f"BatchEncoding(data={s})"
if obj.__class__.__name__ == "VirtualTensor":
return (
f"{obj.__class__.__name__}(name={obj.name!r}, "
f"dtype={obj.dtype}, shape={obj.shape})"
)
if obj.__class__.__name__ == "_DimHint":
return str(obj)
if isinstance(obj, torch.nn.Module):
return f"{obj.__class__.__name__}(...)"
if isinstance(obj, (torch.device, torch.dtype, torch.memory_format, torch.layout)):
return f"{obj.__class__.__name__}({obj})"
if isinstance(obj, torch.utils._pytree.TreeSpec):
return repr(obj).replace(" ", "").replace("\n", " ")
if ignore:
return f"{obj.__class__.__name__}(...)"
raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
def string_signature(sig: Any) -> str:
"""Displays the signature of a functions."""
def _k(p, kind):
for name in dir(p):
if getattr(p, name) == kind:
return name
return repr(kind)
text = [" __call__ ("]
for p in sig.parameters:
pp = sig.parameters[p]
        kind = pp.kind  # keep the enum so it can be compared and named below
t = f"{p}: {pp.annotation}" if pp.annotation is not inspect._empty else p
if pp.default is not inspect._empty:
t = f"{t} = {pp.default!r}"
if kind == pp.VAR_POSITIONAL:
t = f"*{t}"
le = (30 - len(t)) * " "
text.append(f" {t}{le}|{_k(pp,kind)}")
text.append(
f") -> {sig.return_annotation}" if sig.return_annotation is not inspect._empty else ")"
)
return "\n".join(text)
def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
"""
    Displays a call to a function as a string, showing an argument
    only if the given value differs from the parameter default.
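
    A short example, assuming ``string_sig`` is importable from
    ``onnx_diagnostic.helpers`` like the other helpers of this module:

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers import string_sig

        def f(a, b=2, name="s"):
            pass

        # b keeps its default value and is not displayed
        print(string_sig(f, dict(a=1, b=2, name="t")))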
"""
if hasattr(f, "__init__") and kwargs is None:
fct = f.__init__
kwargs = f.__dict__
name = f.__class__.__name__
else:
fct = f
name = f.__name__
if kwargs is None:
kwargs = {}
rows = []
sig = inspect.signature(fct)
for p in sig.parameters:
pp = sig.parameters[p]
d = pp.default
if d is inspect._empty:
if p in kwargs:
v = kwargs[p]
rows.append(
f"{p}={v!r}" if not isinstance(v, enum.IntEnum) else f"{p}={v.name}"
)
continue
v = kwargs.get(p, d)
if d != v:
rows.append(f"{p}={v!r}" if not isinstance(v, enum.IntEnum) else f"{p}={v.name}")
continue
atts = ", ".join(rows)
return f"{name}({atts})"
def make_hash(obj: Any) -> str:
"""
    Returns a simple three-letter hash of ``id(obj)``.
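
    A short example; the letters depend on ``id(obj)``,
    so the output changes from one run to the next:

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers import make_hash

        print(make_hash([]))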
"""
aa = id(obj) % (26**3)
return f"{chr(65 + aa // 26 ** 2)}{chr(65 + (aa // 26) % 26)}{chr(65 + aa % 26)}"
def rename_dynamic_dimensions(
constraints: Dict[str, Set[str]], original: Set[str], ban_prefix: str = "DYN"
) -> Dict[str, str]:
"""
Renames dynamic shapes as requested by the user. :func:`torch.export.export` uses
many names for dynamic dimensions. When building the onnx model,
some of them are redundant and can be replaced by the name provided by the user.
:param constraints: exhaustive list of used names and all the values equal to it
:param original: the names to use if possible
:param ban_prefix: avoid any rewriting by a constant starting with this prefix
:return: replacement dictionary
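
    A small illustrative example (the dimension names are made up):

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers import rename_dynamic_dimensions

        constraints = {"s0": {"batch", "s1"}, "s1": {"s0"}}
        print(rename_dynamic_dimensions(constraints, original={"batch"}))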
"""
replacements = {s: s for s in original}
all_values = set(constraints) | original
not_done = set(constraints)
max_iter = len(replacements)
while not_done and max_iter > 0:
max_iter -= 1
for k, v in constraints.items():
common = v & original
if not common:
continue
sorted_common = sorted(common)
by = sorted_common[0]
if ban_prefix and by.startswith(ban_prefix):
continue
replacements[k] = by
for vv in v:
if vv not in replacements:
replacements[vv] = by
not_done = all_values - set(replacements)
return replacements
def rename_dynamic_expression(expression: str, replacements: Dict[str, str]):
"""
Renames variables of an expression.
:param expression: something like ``s15 + seq_length``
:param replacements: replacements to make
:return: new string
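
    A short example:

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers import rename_dynamic_expression

        print(rename_dynamic_expression("s15 + seq_length", {"s15": "batch"}))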
"""
class RenameVariable(ast.NodeTransformer):
def visit_Name(self, node):
if node.id in replacements:
node.id = replacements[node.id]
return node
tree = ast.parse(expression)
transformer = RenameVariable()
new_tree = transformer.visit(tree)
return ast.unparse(new_tree)
def flatten_object(x: Any, drop_keys: bool = False) -> Any:
"""
Flattens the object.
It accepts some common classes used in deep learning.
:param x: any object
:param drop_keys: drop the keys if a dictionary is flattened.
Keeps the order defined by the dictionary if False, sort them if True.
:return: flattened object
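
    A short example mixing lists, tuples and dictionaries:

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers import flatten_object

        print(flatten_object([1, (2, 3), {"a": 4}]))
        # drop_keys=True keeps only the dictionary values
        print(flatten_object([1, (2, 3), {"a": 4}], drop_keys=True))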
"""
if x is None:
return x
if isinstance(x, (list, tuple)):
res = []
for i in x:
if i is None or hasattr(i, "shape") or isinstance(i, (int, float, str)):
res.append(i)
else:
res.extend(flatten_object(i, drop_keys=drop_keys))
return tuple(res) if isinstance(x, tuple) else res
if isinstance(x, dict):
        # Flatten the dictionary: keep the (key, value) pairs or only the values.
if drop_keys:
return flatten_object(list(x.values()), drop_keys=drop_keys)
return flatten_object(list(x.items()), drop_keys=drop_keys)
if x.__class__.__name__ == "DynamicCache":
res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
return tuple(res)
if x.__class__.__name__ == "EncoderDecoderCache":
res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
return tuple(res)
if x.__class__.__name__ == "MambaCache":
if isinstance(x.conv_states, list):
res = flatten_object(x.conv_states) + flatten_object(x.ssm_states)
return tuple(res)
return (x.conv_states, x.ssm_states)
if hasattr(x, "to_tuple"):
return flatten_object(x.to_tuple(), drop_keys=drop_keys)
if hasattr(x, "shape"):
# A tensor. Nothing to do.
return x
raise TypeError(
f"Unexpected type {type(x)} for x, drop_keys={drop_keys}, "
f"content is {string_type(x, with_shape=True)}"
)
def max_diff(
expected: Any,
got: Any,
verbose: int = 0,
level: int = 0,
flatten: bool = False,
debug_info: Optional[List[str]] = None,
begin: int = 0,
end: int = -1,
_index: int = 0,
allow_unique_tensor_with_list_of_one_element: bool = True,
) -> Dict[str, float]:
"""
Returns the maximum discrepancy.
:param expected: expected values
:param got: values
:param verbose: verbosity level
    :param level: for embedded outputs, used for debug purposes
    :param flatten: flatten outputs
    :param debug_info: debug information
    :param begin: first output to be considered
    :param end: last output to be considered (-1 for the last one)
    :param _index: used with begin and end
    :param allow_unique_tensor_with_list_of_one_element:
        allow a comparison between a single tensor and a list of one tensor
    :return: dictionary with many values

    * abs: maximum absolute error
    * rel: maximum relative error
    * sum: sum of the errors
    * n: number of output values; if there is one output,
      this number is the number of elements of this output
    * dnan: difference in the number of nan

You may use :func:`string_diff` to display the discrepancies in one string.
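
    A short example comparing two tensors:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import max_diff

        t1 = torch.tensor([1.0, 2.0, 3.0])
        t2 = torch.tensor([1.0, 2.5, 3.0])
        print(max_diff(t1, t2))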
"""
if expected is None and got is None:
return dict(abs=0, rel=0, sum=0, n=0, dnan=0)
if allow_unique_tensor_with_list_of_one_element:
if hasattr(expected, "shape") and isinstance(got, (list, tuple)) and len(got) == 1:
return max_diff(
expected,
got[0],
verbose=verbose,
level=level,
flatten=False,
debug_info=debug_info,
allow_unique_tensor_with_list_of_one_element=False,
)
return max_diff(
expected,
got,
verbose=verbose,
level=level,
flatten=flatten,
debug_info=debug_info,
begin=begin,
end=end,
_index=_index,
allow_unique_tensor_with_list_of_one_element=False,
)
if hasattr(expected, "to_tuple"):
if verbose >= 6:
print(f"[max_diff] to_tuple1: {string_type(expected)} ? {string_type(got)}")
return max_diff(
expected.to_tuple(),
got,
verbose=verbose,
level=level + 1,
debug_info=(
[*(debug_info if debug_info else []), f"{' ' * level}to_tupleA"]
if verbose > 5
else None
),
begin=begin,
end=end,
_index=_index,
flatten=flatten,
)
if hasattr(got, "to_tuple"):
if verbose >= 6:
print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
return max_diff(
expected,
got.to_tuple(),
verbose=verbose,
level=level + 1,
debug_info=(
[*(debug_info if debug_info else []), f"{' ' * level}to_tupleB"]
if verbose > 5
else None
),
begin=begin,
end=end,
_index=_index,
flatten=flatten,
)
    if isinstance(got, (list, tuple)) and not isinstance(expected, (tuple, list)):
        # A single expected value can only match a sequence of one element,
        # sequence against sequence is handled further below.
if len(got) != 1:
if verbose >= 6:
print(
f"[max_diff] list,tuple,2: {string_type(expected)} "
f"? {string_type(got)}"
)
if verbose > 2:
import torch
print(
f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, "
f"len(got)={len(got)}, level={level}, _index={_index}"
)
for i, (a, b) in enumerate(zip(expected, got)):
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
print(
f" i={i} expected {a.dtype}:{a.shape}, "
f"has {b.dtype}:{b.shape}, _index={_index}"
)
else:
print(
f" i={i} a is {type(a)}, "
f"b is {type(b)}, _index={_index}"
)
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if verbose >= 6:
print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}")
return max_diff(
expected,
got[0],
verbose=verbose,
level=level + 1,
begin=begin,
end=end,
_index=_index,
debug_info=debug_info,
flatten=flatten,
)
if isinstance(expected, (tuple, list)):
if verbose >= 6:
print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
if len(expected) == 1 and not isinstance(got, type(expected)):
if verbose >= 6:
print(f"[max_diff] list,tuple,3: {string_type(expected)} ? {string_type(got)}")
return max_diff(
expected[0],
got,
verbose=verbose,
level=level + 1,
begin=begin,
end=end,
_index=_index,
debug_info=debug_info,
flatten=flatten,
)
if not isinstance(got, (tuple, list)):
if verbose >= 6:
print(f"[max_diff] list,tuple,4: {string_type(expected)} ? {string_type(got)}")
if verbose > 2:
print(
f"[max_diff] inf because type(expected)={type(expected)}, "
f"type(got)={type(got)}, level={level}, _index={_index}"
)
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if len(got) != len(expected):
if flatten:
if verbose >= 6:
print(
f"[max_diff] list,tuple,5: {string_type(expected)} "
f"? {string_type(got)}"
)
# Let's flatten.
if verbose > 2:
print(
f"[max_diff] flattening because of length mismatch, "
f"expected is\n {string_type(expected)}\n -- and got is\n "
f"{string_type(got)}"
)
flat_a = flatten_object(expected, drop_keys=True)
flat_b = flatten_object(got, drop_keys=True)
if verbose > 2:
print(
f"[max_diff] after flattening, "
f"expected is\n {string_type(flat_a)}\n -- and got is\n "
f"{string_type(flat_b)}"
)
return max_diff(
flat_a,
flat_b,
verbose=verbose,
level=level,
begin=begin,
end=end,
_index=_index,
debug_info=(
[
*(debug_info if debug_info else []),
(
f"{' ' * level}flatten["
f"{string_type(expected)},{string_type(got)}]"
),
]
if verbose > 5
else None
),
flatten=False,
)
if verbose > 2:
import torch
print(
f"[max_diff] (b) inf because len(expected)={len(expected)}, "
f"len(got)={len(got)}, level={level}, _index={_index}"
)
for i, (a, b) in enumerate(zip(expected, got)):
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
print(
f" i={i} expected {a.dtype}:{a.shape}, "
f"has {b.dtype}:{b.shape}, _index={_index}"
)
else:
print(f" i={i} a is {type(a)}, b is {type(b)}")
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if verbose >= 6:
print(f"[max_diff] list,tuple,6: {string_type(expected)} ? {string_type(got)}")
am, rm, sm, n, dn = 0, 0, 0.0, 0.0, 0
for ip, (e, g) in enumerate(zip(expected, got)):
d = max_diff(
e,
g,
verbose=verbose,
level=level + 1,
debug_info=(
[
*(debug_info if debug_info else []),
f"{' ' * level}[{ip}] so far abs {am} - rel {rm}",
]
if verbose > 5
else None
),
begin=begin,
end=end,
_index=_index + ip,
flatten=flatten,
)
am = max(am, d["abs"])
dn = max(dn, d["dnan"])
rm = max(rm, d["rel"])
sm += d["sum"]
n += d["n"]
return dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
if isinstance(expected, dict):
if verbose >= 6:
print(f"[max_diff] dict: {string_type(expected)} ? {string_type(got)}")
assert (
begin == 0 and end == -1
), f"begin={begin}, end={end} not compatible with dictionaries"
if isinstance(got, dict):
if len(expected) != len(got):
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if set(expected) != set(got):
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
keys = sorted(expected)
return max_diff(
[expected[k] for k in keys],
[got[k] for k in keys],
level=level,
flatten=flatten,
debug_info=debug_info,
begin=begin,
end=end,
_index=_index,
verbose=verbose,
)
if not isinstance(got, (tuple, list)):
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if len(expected) != len(got):
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
return max_diff(
list(expected.values()),
got,
level=level,
flatten=flatten,
debug_info=debug_info,
begin=begin,
end=end,
_index=_index,
verbose=verbose,
)
import torch
if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
if isinstance(expected, torch.Tensor):
expected = expected.detach().cpu().numpy()
if isinstance(got, torch.Tensor):
got = got.detach().cpu().numpy()
if verbose >= 6:
print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
if _index < begin or (end != -1 and _index >= end):
# out of boundary
return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
if isinstance(expected, (int, float)):
if isinstance(got, np.ndarray) and len(got.shape) == 0:
got = float(got)
if isinstance(got, (int, float)):
if expected == got:
return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
return dict(
abs=abs(expected - got),
rel=abs(expected - got) / (abs(expected) + 1e-5),
sum=abs(expected - got),
n=1,
dnan=0,
)
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
if expected.dtype in (np.complex64, np.complex128):
if got.dtype == expected.dtype:
got = np.real(got)
elif got.dtype not in (np.float32, np.float64):
if verbose >= 10:
                    # To understand where the value comes from.
if debug_info:
print("\n".join(debug_info))
print(
f"[max_diff-c] expected.dtype={expected.dtype}, "
f"got.dtype={got.dtype}"
)
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
expected = np.real(expected)
if expected.shape != got.shape:
if verbose >= 10:
            # To understand where the value comes from.
if debug_info:
print("\n".join(debug_info))
print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}")
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        # NaNs are replaced by 1e10; any discrepancy of that order
        # of magnitude is likely caused by NaNs.
exp_cpu = np.nan_to_num(expected.astype(np.float64), nan=1e10)
got_cpu = np.nan_to_num(got.astype(np.float64), nan=1e10)
diff = np.abs(got_cpu - exp_cpu)
ndiff = np.abs(np.isnan(expected).astype(int) - np.isnan(got).astype(int))
rdiff = diff / (np.abs(exp_cpu) + 1e-3)
if diff.size == 0:
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
(0, 0, 0, 0, 0)
if exp_cpu.size == got_cpu.size
else (np.inf, np.inf, np.inf, 0, np.inf)
)
else:
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
float(diff.max()),
float(rdiff.max()),
float(diff.sum()),
float(diff.size),
float(ndiff.sum()),
)
if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
        # To understand where the value comes from.
if debug_info:
print("\n".join(debug_info))
print(
f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, "
f"nan_diff={nan_diff}, dtype={expected.dtype}, "
f"shape={expected.shape}, level={level}, _index={_index}"
)
if abs_diff >= 10:
idiff = np.argmax(diff.reshape((-1,)))
x = expected.reshape((-1,))[idiff]
y = got.reshape((-1,))[idiff]
print(
f" [max_diff-2] abs diff={abs_diff}, "
f"x={x}, y={y}, level={level}, "
f"_index={_index}"
)
print(y)
if rel_diff >= 10:
idiff = np.argmax(rdiff.reshape((-1,)))
x = expected.reshape((-1,))[idiff]
y = got.reshape((-1,))[idiff]
print(
f" [max_diff-3] rel diff={rel_diff}, "
f"x={x}, y={y}, level={level}, "
f"_index={_index}"
)
return dict(abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff)
if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
if verbose >= 6:
print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
if _index < begin or (end != -1 and _index >= end):
# out of boundary
return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
if expected.dtype in (torch.complex64, torch.complex128):
if got.dtype == expected.dtype:
got = torch.view_as_real(got)
elif got.dtype not in (torch.float32, torch.float64):
if verbose >= 10:
                    # To understand where the value comes from.
if debug_info:
print("\n".join(debug_info))
print(
f"[max_diff-c] expected.dtype={expected.dtype}, "
f"got.dtype={got.dtype}"
)
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
expected = torch.view_as_real(expected)
if expected.shape != got.shape:
if verbose >= 10:
            # To understand where the value comes from.
if debug_info:
print("\n".join(debug_info))
print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}")
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
        # NaNs are replaced by 1e10; any discrepancy of that order
        # of magnitude is likely caused by NaNs.
exp_cpu = expected.to(torch.float64).cpu().nan_to_num(1e10)
got_cpu = got.to(torch.float64).cpu().nan_to_num(1e10)
diff = (got_cpu - exp_cpu).abs()
ndiff = (expected.isnan().cpu().to(int) - got.isnan().cpu().to(int)).abs()
rdiff = diff / (exp_cpu.abs() + 1e-3)
if diff.numel() > 0:
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
float(diff.max()),
float(rdiff.max()),
float(diff.sum()),
float(diff.numel()),
float(ndiff.sum()),
)
elif got_cpu.numel() == exp_cpu.numel():
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (0.0, 0.0, 0.0, 0.0, 0.0)
else:
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
np.inf,
np.inf,
np.inf,
np.inf,
np.inf,
)
if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
            # To understand where the value comes from.
if debug_info:
print("\n".join(debug_info))
print(
f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, "
f"nan_diff={nan_diff}, dtype={expected.dtype}, "
f"shape={expected.shape}, level={level}, _index={_index}"
)
if abs_diff >= 10:
idiff = torch.argmax(diff.reshape((-1,)))
x = expected.reshape((-1,))[idiff]
y = got.reshape((-1,))[idiff]
print(
f" [max_diff-2] abs diff={abs_diff}, "
f"x={x}, y={y}, level={level}, "
f"_index={_index}"
)
print(y)
if rel_diff >= 10:
idiff = torch.argmax(rdiff.reshape((-1,)))
x = expected.reshape((-1,))[idiff]
y = got.reshape((-1,))[idiff]
print(
f" [max_diff-3] rel diff={rel_diff}, "
f"x={x}, y={y}, level={level}, "
f"_index={_index}"
)
return dict(abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff)
if "SquashedNormal" in expected.__class__.__name__:
if verbose >= 6:
print(f"[max_diff] SquashedNormal: {string_type(expected)} ? {string_type(got)}")
values = (
expected.mean.detach().to("cpu"),
expected.scale.detach().to("cpu"),
)
return max_diff(
values,
got,
verbose=verbose,
level=level + 1,
begin=begin,
end=end,
_index=_index,
flatten=flatten,
)
if expected.__class__.__name__ == "DynamicCache":
if got.__class__.__name__ == "DynamicCache":
if verbose >= 6:
print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
return max_diff(
[expected.key_cache, expected.value_cache],
[got.key_cache, got.value_cache],
verbose=verbose,
)
if isinstance(got, tuple) and len(got) == 2:
return max_diff(
[expected.key_cache, expected.value_cache],
[got[0], got[1]],
verbose=verbose,
)
raise AssertionError(
f"DynamicCache not fully implemented with classes "
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
f"and expected={string_type(expected)}, got={string_type(got)},\n"
f"level={level}"
)
if expected.__class__.__name__ == "EncoderDecoderCache":
if got.__class__.__name__ == "EncoderDecoderCache":
if verbose >= 6:
print(
f"[max_diff] EncoderDecoderCache: "
f"{string_type(expected)} ? {string_type(got)}"
)
return max_diff(
[expected.self_attention_cache, expected.cross_attention_cache],
[got.self_attention_cache, got.cross_attention_cache],
verbose=verbose,
)
if isinstance(got, tuple) and len(got) == 2:
return max_diff(
[expected.self_attention_cache, expected.cross_attention_cache],
[got[0], got[1]],
verbose=verbose,
)
raise AssertionError(
f"EncoderDecoderCache not fully implemented with classes "
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
f"and expected={string_type(expected)}, got={string_type(got)},\n"
f"level={level}"
)
    if expected.__class__.__name__ == "MambaCache":  # transformers.cache_utils.MambaCache
if verbose >= 6:
print(f"[max_diff] MambaCache: {string_type(expected)} ? {string_type(got)}")
if got.__class__.__name__ != expected.__class__.__name__:
# This case happens with onnx where the outputs are flattened.
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
atts = []
for k in ["conv_states", "ssm_states"]:
if hasattr(expected, k) and not hasattr(got, k):
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
atts.append(k)
return max_diff(
[getattr(expected, k) for k in atts],
[getattr(got, k) for k in atts],
level=level,
flatten=flatten,
debug_info=debug_info,
begin=begin,
end=end,
_index=_index,
verbose=verbose,
)
    raise AssertionError(
        f"max_diff not implemented for expected="
        f"{string_type(expected)}, got={string_type(got)},\n"
        f"level={level}"
    )
def string_diff(diff: Dict[str, Any]) -> str:
"""Renders discrepancies return by :func:`max_diff` into one string."""
# dict(abs=, rel=, sum=, n=n_diff, dnan=)
if diff.get("dnan", None):
if diff["abs"] == 0 or diff["rel"] == 0:
return f"abs={diff['abs']}, rel={diff['rel']}, dnan={diff['dnan']}"
return f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}, dnan={diff['dnan']}"
if diff["abs"] == 0 or diff["rel"] == 0:
return f"abs={diff['abs']}, rel={diff['rel']}"
return f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}"