"""
This module contains the main class ``ExtTestCase``, which adds
functionalities specific to this project.
"""
import copy
import glob
import itertools
import logging
import os
import re
import shutil
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from timeit import Timer
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
import numpy
from numpy.testing import assert_allclose
BOOLEAN_VALUES = (1, "1", True, "True", "true", "TRUE")
def is_azure() -> bool:
"""Tells if the job is running on Azure DevOps."""
return os.environ.get("AZURE_HTTP_USER_AGENT", "undefined") != "undefined"
def is_windows() -> bool:
    """Tells if the platform is Windows."""
    return sys.platform == "win32"
def is_apple() -> bool:
    """Tells if the platform is macOS."""
    return sys.platform == "darwin"
def is_linux() -> bool:
    """Tells if the platform is Linux."""
    return sys.platform == "linux"
def skipif_ci_windows(msg) -> Callable:
"""Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`."""
if is_windows() and is_azure():
msg = f"Test does not work on azure pipeline (Windows). {msg}"
return unittest.skip(msg)
return lambda x: x
def skipif_ci_linux(msg) -> Callable:
"""Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Linux`."""
if is_linux() and is_azure():
msg = f"Takes too long (Linux). {msg}"
return unittest.skip(msg)
return lambda x: x
def skipif_ci_apple(msg) -> Callable:
    """Skips a unit test if it runs on :epkg:`azure pipeline` on macOS."""
if is_apple() and is_azure():
msg = f"Test does not work on azure pipeline (Apple). {msg}"
return unittest.skip(msg)
return lambda x: x
def unit_test_going() -> bool:
"""
    Tells if the script is running as part of a unit test
    (environment variable ``UNITTEST_GOING=1``).
    It helps keeping unit tests short.
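
    A minimal usage sketch (illustrative):

    ::

        n_iterations = 2 if unit_test_going() else 1000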
"""
going = int(os.environ.get("UNITTEST_GOING", 0))
return going == 1
def ignore_warnings(warns: List[Warning]) -> Callable:
    """
    Catches and ignores the given warnings while running the test.
:param warns: warnings to ignore
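
    Usage sketch (illustrative, as a decorator on a test method):

    ::

        class MyTest(ExtTestCase):
            @ignore_warnings(DeprecationWarning)
            def test_noisy(self):
                ...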
"""
if not isinstance(warns, (tuple, list)):
warns = (warns,)
new_list = []
for w in warns:
if w == "TracerWarning":
from torch.jit import TracerWarning
new_list.append(TracerWarning)
else:
new_list.append(w)
warns = tuple(new_list)
def wrapper(fct):
if warns is None:
raise AssertionError(f"warns cannot be None for '{fct}'.")
def call_f(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore", warns)
return fct(self)
try: # noqa: SIM105
call_f.__name__ = fct.__name__
except AttributeError:
pass
return call_f
return wrapper
def ignore_errors(errors: Union[Type[Exception], Tuple[Type[Exception], ...]]) -> Callable:
    """
    Catches exceptions and skips the test when one of the expected errors is raised.
:param errors: errors to ignore
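
    Usage sketch (illustrative):

    ::

        class MyTest(ExtTestCase):
            @ignore_errors((ConnectionError, TimeoutError))
            def test_download(self):
                ...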
"""
def wrapper(fct):
if errors is None:
raise AssertionError(f"errors cannot be None for '{fct}'.")
def call_f(self):
try:
return fct(self)
except errors as e:
raise unittest.SkipTest( # noqa: B904
f"expecting error {e.__class__.__name__}: {e}"
)
try: # noqa: SIM105
call_f.__name__ = fct.__name__
except AttributeError:
pass
return call_f
return wrapper
def hide_stdout(f: Optional[Callable] = None) -> Callable:
"""
Catches warnings, hides standard output.
The function may be disabled by setting ``UNHIDE=1``
before running the unit test.
:param f: the function is called with the stdout as an argument
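
    Usage sketch (illustrative):

    ::

        class MyTest(ExtTestCase):
            @hide_stdout()
            def test_verbose(self):
                print("hidden unless UNHIDE=1")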
"""
def wrapper(fct):
def call_f(self):
if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"):
fct(self)
return
st = StringIO()
with redirect_stdout(st), warnings.catch_warnings():
warnings.simplefilter("ignore", (UserWarning, DeprecationWarning))
try:
fct(self)
except AssertionError as e:
if "torch is not recent enough, file" in str(e):
raise unittest.SkipTest(str(e)) # noqa: B904
raise
if f is not None:
f(st.getvalue())
return None
try: # noqa: SIM105
call_f.__name__ = fct.__name__
except AttributeError:
pass
return call_f
return wrapper
def long_test(msg: str = "") -> Callable:
    """Skips a unit test unless environment variable ``LONGTEST`` is set to 1."""
    if os.environ.get("LONGTEST", "0") in ("0", "False", "false"):
        msg = f"Skipped (set LONGTEST=1 to run it). {msg}"
return unittest.skip(msg)
return lambda x: x
def never_test(msg: str = "") -> Callable:
    """Skips a unit test unless environment variable ``NEVERTEST`` is set to 1."""
    if os.environ.get("NEVERTEST", "0") in ("0", "False", "false"):
        msg = f"Skipped (set NEVERTEST=1 to run it). {msg}"
return unittest.skip(msg)
return lambda x: x
def measure_time(
stmt: Union[str, Callable],
context: Optional[Dict[str, Any]] = None,
repeat: int = 10,
number: int = 50,
warmup: int = 1,
div_by_number: bool = True,
max_time: Optional[float] = None,
) -> Dict[str, Union[str, int, float]]:
"""
Measures a statement and returns the results as a dictionary.
:param stmt: string or callable
    :param context: dictionary of variables the statement needs to run
    :param repeat: average over *repeat* experiments
    :param number: number of executions per experiment
    :param warmup: number of iterations to run before starting the
        real measurement
:param div_by_number: divide by the number of executions
:param max_time: execute the statement until the total goes
beyond this time (approximately), *repeat* is ignored,
*div_by_number* must be set to True
:return: dictionary
.. runpython::
:showcode:
from pprint import pprint
from math import cos
from onnx_diagnostic.ext_test_case import measure_time
res = measure_time(lambda: cos(0.5))
pprint(res)
    See `Timer.repeat <https://docs.python.org/3/library/
    timeit.html#timeit.Timer.repeat>`_
    for a better understanding of parameters *repeat* and *number*.
The function returns a duration corresponding to
*number* times the execution of the main statement.
"""
if not callable(stmt) and not isinstance(stmt, str):
raise TypeError(f"stmt is not callable or a string but is of type {type(stmt)!r}.")
if context is None:
context = {}
if isinstance(stmt, str):
tim = Timer(stmt, globals=context)
else:
tim = Timer(stmt)
if warmup > 0:
warmup_time = tim.timeit(warmup)
else:
warmup_time = 0
if max_time is not None:
if not div_by_number:
            raise ValueError("div_by_number must be set to True if max_time is defined.")
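        # Measurement strategy: grow *number* roughly geometrically
        # (1, 2, 2, 4, 4, 8, ...), accumulating (number, time) pairs
        # until the total measured time exceeds *max_time*.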
i = 1
total_time = 0.0
results = []
while True:
for j in (1, 2):
number = i * j
time_taken = tim.timeit(number)
results.append((number, time_taken))
total_time += time_taken
if total_time >= max_time:
break
if total_time >= max_time:
break
ratio = (max_time - total_time) / total_time
ratio = max(ratio, 1)
i = int(i * ratio)
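        # Aggregate the measures: *average* is the total time divided by the
        # total number of executions, *deviation* the execution-count-weighted
        # standard deviation of the per-execution times.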
res = numpy.array(results)
tw = res[:, 0].sum()
ttime = res[:, 1].sum()
mean = ttime / tw
ave = res[:, 1] / res[:, 0]
dev = (((ave - mean) ** 2 * res[:, 0]).sum() / tw) ** 0.5
mes = dict(
average=mean,
deviation=dev,
min_exec=numpy.min(ave),
max_exec=numpy.max(ave),
repeat=1,
number=tw,
ttime=ttime,
)
else:
res = numpy.array(tim.repeat(repeat=repeat, number=number))
if div_by_number:
res /= number
mean = numpy.mean(res)
dev = numpy.mean(res**2)
dev = (dev - mean**2) ** 0.5
mes = dict(
average=mean,
deviation=dev,
min_exec=numpy.min(res),
max_exec=numpy.max(res),
repeat=repeat,
number=number,
ttime=res.sum(),
)
if "values" in context:
if hasattr(context["values"], "shape"):
mes["size"] = context["values"].shape[0]
else:
mes["size"] = len(context["values"])
else:
mes["context_size"] = sys.getsizeof(context)
mes["warmup_time"] = warmup_time
return mes
def statistics_on_folder(
folder: Union[str, List[str]],
pattern: str = ".*[.]((py|rst))$",
aggregation: int = 0,
) -> List[Dict[str, Union[int, float, str]]]:
"""
Computes statistics on files in a folder.
:param folder: folder or folders to investigate
:param pattern: file pattern
    :param aggregation: aggregates the results over the first *aggregation*
        levels of subfolders
:return: list of dictionaries
.. runpython::
:showcode:
:toggle:
import os
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__
pprint.pprint(statistics_on_folder(os.path.dirname(__file__)))
Aggregated:
.. runpython::
:showcode:
:toggle:
import os
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__
pprint.pprint(statistics_on_folder(os.path.dirname(__file__), aggregation=1))
"""
if isinstance(folder, list):
rows = []
for fold in folder:
last = fold.replace("\\", "/").split("/")[-1]
r = statistics_on_folder(
fold, pattern=pattern, aggregation=max(aggregation - 1, 0)
)
if aggregation == 0:
rows.extend(r)
continue
for line in r:
line["dir"] = os.path.join(last, line["dir"])
rows.extend(r)
return rows
rows = []
reg = re.compile(pattern)
for name in glob.glob("**/*", root_dir=folder, recursive=True):
if not reg.match(name):
continue
if os.path.isdir(os.path.join(folder, name)):
continue
stat = statistics_on_file(os.path.join(folder, name))
stat["name"] = name
if aggregation <= 0:
rows.append(stat)
continue
spl = os.path.dirname(name).replace("\\", "/").split("/")
level = "/".join(spl[:aggregation])
stat["dir"] = level
rows.append(stat)
return rows
def has_cuda() -> bool:
"""Returns ``torch.cuda.device_count() > 0``."""
import torch
return torch.cuda.device_count() > 0
def requires_python(version: Tuple[int, ...], msg: str = ""):
"""
Skips a test if python is too old.
    :param version: minimum Python version as a tuple, e.g. ``(3, 10)``
    :param msg: overwrites the default message
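
    Usage sketch (illustrative):

    ::

        @requires_python((3, 10))
        def test_needs_recent_python(self):
            ...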
"""
if sys.version_info[: len(version)] < version:
return unittest.skip(msg or f"python not recent enough {sys.version_info} < {version}")
return lambda x: x
def requires_cuda(msg: str = "", version: str = "", memory: int = 0):
    """
    Skips a test if CUDA is not available.

    :param msg: overwrites the default message
    :param version: minimum CUDA version
    :param memory: minimum amount of GPU memory in GB required to run the test
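
    Usage sketch (illustrative):

    ::

        @requires_cuda(version="11.8", memory=8)
        def test_large_cuda_model(self):
            ...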
"""
import torch
    if torch.cuda.device_count() == 0:
        msg = msg or "only runs on CUDA but torch does not have it"
        return unittest.skip(msg)
    if version:
        import packaging.version as pv
        if pv.Version(torch.version.cuda) < pv.Version(version):
            msg = msg or f"CUDA not recent enough {torch.version.cuda} < {version}"
            return unittest.skip(msg)
if memory:
m = torch.cuda.get_device_properties(0).total_memory / 2**30
if m < memory:
msg = msg or f"available memory is not enough {m} < {memory} (Gb)"
return unittest.skip(msg)
return lambda x: x
def requires_zoo(msg: str = "") -> Callable:
"""Skips a unit test if environment variable ZOO is not equal to 1."""
var = os.environ.get("ZOO", "0") in BOOLEAN_VALUES
if not var:
msg = f"ZOO not set up or != 1. {msg}"
return unittest.skip(msg or "zoo not installed")
return lambda x: x
def requires_sklearn(version: str, msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`scikit-learn` is not recent enough."""
import packaging.version as pv
import sklearn
if pv.Version(sklearn.__version__) < pv.Version(version):
msg = f"scikit-learn version {sklearn.__version__} < {version}: {msg}"
return unittest.skip(msg)
return lambda x: x
def requires_experimental(version: str = "0.0.0", msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`experimental-experiment` is not recent enough."""
import packaging.version as pv
try:
import experimental_experiment
except ImportError:
msg = f"experimental-experiment not installed: {msg}"
return unittest.skip(msg)
if pv.Version(experimental_experiment.__version__) < pv.Version(version):
msg = (
f"experimental-experiment version "
f"{experimental_experiment.__version__} < {version}: {msg}"
)
return unittest.skip(msg)
return lambda x: x
def has_torch(version: str) -> bool:
    "Returns True if the installed :epkg:`torch` version is at least *version*."
import packaging.version as pv
import torch
return pv.Version(torch.__version__) >= pv.Version(version)
def requires_torch(version: str, msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`pytorch` is not recent enough."""
import packaging.version as pv
import torch
if pv.Version(torch.__version__) < pv.Version(version):
msg = f"torch version {torch.__version__} < {version}: {msg}"
return unittest.skip(msg)
return lambda x: x
def requires_numpy(version: str, msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`numpy` is not recent enough."""
import packaging.version as pv
import numpy
if pv.Version(numpy.__version__) < pv.Version(version):
msg = f"numpy version {numpy.__version__} < {version}: {msg}"
return unittest.skip(msg)
return lambda x: x
def requires_diffusers(
    version: str, msg: str = "", or_older_than: Optional[str] = None
) -> Callable:
    """Skips a unit test if :epkg:`diffusers` is not recent enough."""
import packaging.version as pv
try:
import diffusers
except ImportError:
msg = f"diffusers not installed {msg}"
return unittest.skip(msg)
v = pv.Version(diffusers.__version__)
if v < pv.Version(version):
msg = f"diffusers version {diffusers.__version__} < {version} {msg}"
return unittest.skip(msg)
if or_older_than and v > pv.Version(or_older_than):
msg = (
f"diffusers version {or_older_than} < "
f"{diffusers.__version__} < {version} {msg}"
)
return unittest.skip(msg)
return lambda x: x
def requires_onnxscript(version: str, msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`onnxscript` is not recent enough."""
import packaging.version as pv
import onnxscript
if not hasattr(onnxscript, "__version__"):
# development version
return lambda x: x
if pv.Version(onnxscript.__version__) < pv.Version(version):
msg = f"onnxscript version {onnxscript.__version__} < {version}: {msg}"
return unittest.skip(msg)
return lambda x: x
def has_onnxscript(version: str, msg: str = "") -> bool:
    """Tells if :epkg:`onnxscript` is recent enough."""
    import packaging.version as pv
    import onnxscript
    if not hasattr(onnxscript, "__version__"):
        # development version
        return True
    if pv.Version(onnxscript.__version__) < pv.Version(version):
        return False
    return True
def requires_onnxruntime(version: str, msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`onnxruntime` is not recent enough."""
import packaging.version as pv
import onnxruntime
if pv.Version(onnxruntime.__version__) < pv.Version(version):
msg = f"onnxruntime version {onnxruntime.__version__} < {version}: {msg}"
return unittest.skip(msg)
return lambda x: x
def has_onnxruntime_training(push_back_batch: bool = False):
"""Tells if onnxruntime_training is installed."""
try:
from onnxruntime import training
except ImportError:
        # onnxruntime-training is not installed
training = None
if training is None:
return False
if push_back_batch:
try:
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
except ImportError:
return False
if not hasattr(OrtValueVector, "push_back_batch"):
return False
return True
def has_onnxruntime_genai():
"""Tells if onnxruntime_genai is installed."""
try:
import onnxruntime_genai # noqa: F401
return True
except ImportError:
        # onnxruntime_genai is not installed
return False
def requires_onnxruntime_training(
push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
) -> Callable:
"""Skips a unit test if :epkg:`onnxruntime` is not onnxruntime_training."""
try:
from onnxruntime import training
except ImportError:
# onnxruntime not training
training = None
if training is None:
msg = msg or "onnxruntime_training is not installed"
return unittest.skip(msg)
if push_back_batch:
try:
from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
except ImportError:
msg = msg or "OrtValue has no method push_back_batch"
return unittest.skip(msg)
if not hasattr(OrtValueVector, "push_back_batch"):
msg = msg or "OrtValue has no method push_back_batch"
return unittest.skip(msg)
if ortmodule:
try:
import onnxruntime.training.ortmodule # noqa: F401
except (AttributeError, ImportError):
msg = msg or "ortmodule is missing in onnxruntime-training"
return unittest.skip(msg)
return lambda x: x
def requires_onnx(version: str, msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`onnx` is not recent enough."""
import packaging.version as pv
import onnx
if pv.Version(onnx.__version__) < pv.Version(version):
msg = f"onnx version {onnx.__version__} < {version}: {msg}"
return unittest.skip(msg)
return lambda x: x
def requires_onnx_array_api(version: str, msg: str = "") -> Callable:
"""Skips a unit test if :epkg:`onnx-array-api` is not recent enough."""
import packaging.version as pv
import onnx_array_api
if pv.Version(onnx_array_api.__version__) < pv.Version(version):
msg = f"onnx-array-api version {onnx_array_api.__version__} < {version}: {msg}"
return unittest.skip(msg)
return lambda x: x
def statistics_on_file(filename: str) -> Dict[str, Union[int, float, str]]:
"""
Computes statistics on a file.
.. runpython::
:showcode:
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_file, __file__
pprint.pprint(statistics_on_file(__file__))
"""
    assert os.path.exists(filename), f"File {filename!r} does not exist."
ext = os.path.splitext(filename)[-1]
if ext not in {".py", ".rst", ".md", ".txt"}:
size = os.stat(filename).st_size
return {"size": size}
alpha = set("abcdefghijklmnopqrstuvwxyz0123456789")
with open(filename, "r", encoding="utf-8") as f:
n_line = 0
n_ch = 0
for line in f.readlines():
s = line.strip("\n\r\t ")
if s:
n_ch += len(s.replace(" ", ""))
ch = set(s.lower()) & alpha
if ch:
                    # It avoids counting lines that contain only a bracket or a comma.
n_line += 1
stat = dict(lines=n_line, chars=n_ch, ext=ext)
if ext != ".py":
return stat
# add statistics on python syntax?
return stat
class ExtTestCase(unittest.TestCase):
"""
    Inherits from :class:`unittest.TestCase` and adds specific comparison
    functions and other helpers.
"""
_warns: List[Tuple[str, int, Warning]] = []
_todos: List[Tuple[Callable, str]] = []
def unit_test_going(self) -> bool:
"""
        Tells if the script is running as part of a unit test
        (environment variable ``UNITTEST_GOING=1``).
        It helps keeping unit tests short.
"""
return unit_test_going()
@property
def verbose(self) -> int:
"Returns the the value of environment variable ``VERBOSE``."
return int(os.environ.get("VERBOSE", "0"))
@classmethod
def setUpClass(cls):
logger = logging.getLogger("onnxscript.optimizer.constant_folding")
logger.setLevel(logging.ERROR)
unittest.TestCase.setUpClass()
@classmethod
def tearDownClass(cls):
for name, line, w in cls._warns:
warnings.warn(f"\n{name}:{line}: {type(w)}\n {w!s}", stacklevel=2)
if not cls._todos:
return
for f, msg in cls._todos:
sys.stderr.write(f"TODO {cls.__name__}::{f.__name__}: {msg}\n")
@classmethod
def todo(cls, f: Callable, msg: str):
"Adds a todo printed when all test are run."
cls._todos.append((f, msg))
@classmethod
    def ort(cls):
        "Returns module :epkg:`onnxruntime`."
import onnxruntime
return onnxruntime
@classmethod
    def to_onnx(cls, *args, **kwargs) -> "ModelProto":  # noqa: F821
from experimental_experiment.torch_interpreter import to_onnx
return to_onnx(*args, **kwargs)
def print_model(self, model: "ModelProto"): # noqa: F821
"Prints a ModelProto"
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
print(pretty_onnx(model))
def print_onnx(self, model: "ModelProto"): # noqa: F821
"Prints a ModelProto"
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
print(pretty_onnx(model))
def get_dump_file(self, name: str, folder: Optional[str] = None) -> str:
"""Returns a filename to dump a model."""
if folder is None:
folder = "dump_test"
if folder and not os.path.exists(folder):
os.mkdir(folder)
return os.path.join(folder, name)
def get_dump_folder(self, folder: str) -> str:
"""Returns a folder."""
folder = os.path.join("dump_test", folder)
if not os.path.exists(folder):
os.makedirs(folder)
return folder
def clean_dump(self, folder: str = "dump_test"):
"""Cleans this folder."""
for item in os.listdir(folder):
item_path = os.path.join(folder, item)
if os.path.isfile(item_path) or os.path.islink(item_path):
os.remove(item_path)
elif os.path.isdir(item_path):
shutil.rmtree(item_path)
def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
"""Dumps an onnx file."""
fullname = self.get_dump_file(name, folder=folder)
with open(fullname, "wb") as f:
f.write(proto.SerializeToString())
return fullname
    def assertExists(self, name):
        """Checks the existence of a file or folder."""
        if not os.path.exists(name):
            raise AssertionError(f"File or folder {name!r} does not exist.")
    def assertGreaterOrEqual(self, a, b, msg=None):
        """In the name"""
        if a < b:
            raise AssertionError(f"{a} < {b}, a is not greater than or equal to b\n{msg or ''}")
def assertInOr(self, tofind: Tuple[str, ...], text: str, msg: str = ""):
for tof in tofind:
if tof in text:
return
raise AssertionError(
msg or f"Unable to find one string in the list {tofind!r} in\n--\n{text}"
)
def assertIn(self, tofind: str, text: str, msg: str = ""):
if tofind in text:
return
raise AssertionError(
msg or f"Unable to find the list of strings {tofind!r} in\n--\n{text}"
)
def assertHasAttr(self, obj: Any, name: str):
assert hasattr(
obj, name
), f"Unable to find attribute {name!r} in object type {type(obj)}"
def assertSetContained(self, set1, set2):
"Checks that ``set1`` is contained in ``set2``."
set1 = set(set1)
set2 = set(set2)
if set1 & set2 != set1:
raise AssertionError(f"Set {set2} does not contain set {set1}.")
def assertEqualArrays(
self,
expected: Sequence[numpy.ndarray],
value: Sequence[numpy.ndarray],
atol: float = 0,
rtol: float = 0,
msg: Optional[str] = None,
):
"""In the name"""
self.assertEqual(len(expected), len(value))
for a, b in zip(expected, value):
self.assertEqualArray(a, b, atol=atol, rtol=rtol)
def assertEqualArray(
self,
expected: Any,
value: Any,
atol: float = 0,
rtol: float = 0,
msg: Optional[str] = None,
):
"""In the name"""
if hasattr(expected, "detach") and hasattr(value, "detach"):
if msg:
try:
self.assertEqual(expected.dtype, value.dtype)
except AssertionError as e:
raise AssertionError(msg) from e
try:
self.assertEqual(expected.shape, value.shape)
except AssertionError as e:
raise AssertionError(msg) from e
else:
self.assertEqual(expected.dtype, value.dtype)
self.assertEqual(expected.shape, value.shape)
import torch
try:
torch.testing.assert_close(value, expected, atol=atol, rtol=rtol)
except AssertionError as e:
expected_max = torch.abs(expected).max()
expected_value = torch.abs(value).max()
rows = [
f"{msg}\n{e}" if msg else str(e),
f"expected max value={expected_max}",
f"expected computed value={expected_value}",
]
raise AssertionError("\n".join(rows)) # noqa: B904
return
from .helpers.torch_helper import to_numpy
if hasattr(expected, "detach"):
expected = to_numpy(expected.detach().cpu())
if hasattr(value, "detach"):
value = to_numpy(value.detach().cpu())
if msg:
try:
self.assertEqual(expected.dtype, value.dtype)
except AssertionError as e:
raise AssertionError(msg) from e
try:
self.assertEqual(expected.shape, value.shape)
except AssertionError as e:
raise AssertionError(msg) from e
else:
self.assertEqual(expected.dtype, value.dtype)
self.assertEqual(expected.shape, value.shape)
try:
assert_allclose(desired=expected, actual=value, atol=atol, rtol=rtol)
except AssertionError as e:
expected_max = numpy.abs(expected).max()
expected_value = numpy.abs(value).max()
te = expected.astype(int) if expected.dtype == numpy.bool_ else expected
tv = value.astype(int) if value.dtype == numpy.bool_ else value
rows = [
f"{msg}\n{e}" if msg else str(e),
f"expected max value={expected_max}",
f"expected computed value={expected_value}\n",
f"ratio={te / tv}\ndiff={te - tv}",
]
raise AssertionError("\n".join(rows)) # noqa: B904
def assertEqualDataFrame(self, d1, d2, **kwargs):
"""
Checks that two dataframes are equal.
Calls :func:`pandas.testing.assert_frame_equal`.
"""
from pandas.testing import assert_frame_equal
assert_frame_equal(d1, d2, **kwargs)
def assertEqualTrue(self, value: Any, msg: str = ""):
if value is True:
return
raise AssertionError(msg or f"value is not True: {value!r}")
def assertEqual(self, expected: Any, value: Any, msg: str = ""):
"""Overwrites the error message to get a more explicit message about what is what."""
if msg:
super().assertEqual(expected, value, msg)
else:
try:
super().assertEqual(expected, value)
except AssertionError as e:
raise AssertionError( # noqa: B904
f"expected is {expected!r}, value is {value!r}\n{e}"
)
def assertEqualAny(
self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = ""
):
if expected.__class__.__name__ == "BaseModelOutput":
self.assertEqual(type(expected), type(value), msg=msg)
self.assertEqual(len(expected), len(value), msg=msg)
self.assertEqual(list(expected), list(value), msg=msg) # checks the order
self.assertEqualAny(
{k: v for k, v in expected.items()}, # noqa: C416
{k: v for k, v in value.items()}, # noqa: C416
atol=atol,
rtol=rtol,
msg=msg,
)
elif isinstance(expected, (tuple, list, dict)):
self.assertIsInstance(value, type(expected), msg=msg)
self.assertEqual(len(expected), len(value), msg=msg)
if isinstance(expected, dict):
for k in expected:
self.assertIn(k, value, msg=msg)
self.assertEqualAny(expected[k], value[k], msg=msg, atol=atol, rtol=rtol)
else:
for e, g in zip(expected, value):
self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
elif expected.__class__.__name__ in (
"DynamicCache",
"SlidingWindowCache",
"HybridCache",
):
self.assertEqual(type(expected), type(value), msg=msg)
atts = ["key_cache", "value_cache"]
self.assertEqualAny(
{k: expected.__dict__.get(k, None) for k in atts},
{k: value.__dict__.get(k, None) for k in atts},
atol=atol,
rtol=rtol,
)
elif expected.__class__.__name__ == "StaticCache":
self.assertEqual(type(expected), type(value), msg=msg)
self.assertEqual(expected.max_cache_len, value.max_cache_len)
atts = ["key_cache", "value_cache"]
self.assertEqualAny(
{k: expected.__dict__.get(k, None) for k in atts},
{k: value.__dict__.get(k, None) for k in atts},
atol=atol,
rtol=rtol,
)
elif expected.__class__.__name__ == "EncoderDecoderCache":
self.assertEqual(type(expected), type(value), msg=msg)
atts = ["self_attention_cache", "cross_attention_cache"]
self.assertEqualAny(
{k: expected.__dict__.get(k, None) for k in atts},
{k: value.__dict__.get(k, None) for k in atts},
atol=atol,
rtol=rtol,
)
elif isinstance(expected, (int, float, str)):
self.assertEqual(expected, value, msg=msg)
elif hasattr(expected, "shape"):
self.assertEqual(type(expected), type(value), msg=msg)
self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
elif expected.__class__.__name__ in ("Dim", "_Dim", "_DimHintType"):
self.assertEqual(type(expected), type(value), msg=msg)
self.assertEqual(expected.__name__, value.__name__, msg=msg)
elif expected is None:
self.assertEqual(expected, value, msg=msg)
else:
raise AssertionError(
f"Comparison not implemented for types {type(expected)} and {type(value)}"
)
def assertEqualArrayAny(
self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = ""
):
if isinstance(expected, (tuple, list, dict)):
self.assertIsInstance(value, type(expected), msg=msg)
self.assertEqual(len(expected), len(value), msg=msg)
if isinstance(expected, dict):
for k in expected:
self.assertIn(k, value, msg=msg)
self.assertEqualArrayAny(
expected[k], value[k], msg=msg, atol=atol, rtol=rtol
)
else:
excs = []
                for i, (exp, got) in enumerate(zip(expected, value)):
                    try:
                        self.assertEqualArrayAny(exp, got, msg=msg, atol=atol, rtol=rtol)
                    except AssertionError as e:
                        excs.append(f"Error at position {i} due to {e}")
if excs:
msg_ = "\n".join(excs)
msg = f"{msg}\n{msg_}" if msg else msg_
raise AssertionError(f"Found {len(excs)} discrepancies\n{msg}")
elif expected.__class__.__name__ in ("DynamicCache", "StaticCache"):
atts = {"key_cache", "value_cache"}
self.assertEqualArrayAny(
{k: expected.__dict__.get(k, None) for k in atts},
{k: value.__dict__.get(k, None) for k in atts},
atol=atol,
rtol=rtol,
)
elif isinstance(expected, (int, float, str)):
self.assertEqual(expected, value, msg=msg)
elif hasattr(expected, "shape"):
self.assertEqual(type(expected), type(value), msg=msg)
self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol)
elif expected is None:
assert value is None, f"Expected is None but value is of type {type(value)}"
else:
raise AssertionError(
f"Comparison not implemented for types {type(expected)} and {type(value)}"
)
def assertAlmostEqual(
self,
expected: numpy.ndarray,
value: numpy.ndarray,
atol: float = 0,
rtol: float = 0,
):
"""In the name"""
if not isinstance(expected, numpy.ndarray):
expected = numpy.array(expected)
if not isinstance(value, numpy.ndarray):
value = numpy.array(value).astype(expected.dtype)
self.assertEqualArray(expected, value, atol=atol, rtol=rtol)
def check_ort(
self, onx: "onnx.ModelProto" # noqa: F821
) -> "onnxruntime.InferenceSession": # noqa: F821
from onnxruntime import InferenceSession
return InferenceSession(
onx if isinstance(onx, str) else onx.SerializeToString(),
providers=["CPUExecutionProvider"],
)
def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
"""In the name"""
try:
fct()
        except exc_type as e:
            # the except clause already guarantees isinstance(e, exc_type)
            if msg is not None and msg not in str(e):
                raise AssertionError(f"Unexpected exception message {e!r}.")  # noqa: B904
return
raise AssertionError("No exception was raised.") # noqa: B904
def assertEmpty(self, value: Any):
"""In the name"""
if value is None:
return
if not value:
return
raise AssertionError(f"value is not empty: {value!r}.")
def assertNotEmpty(self, value: Any):
"""In the name"""
if value is None:
raise AssertionError(f"value is empty: {value!r}.")
if isinstance(value, (list, dict, tuple, set)):
if not value:
raise AssertionError(f"value is empty: {value!r}.")
def assertStartsWith(self, prefix: str, full: str):
"""In the name"""
if not full.startswith(prefix):
raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")
def assertEndsWith(self, suffix: str, full: str):
"""In the name"""
if not full.endswith(suffix):
raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")
def capture(self, fct: Callable) -> Tuple[Any, str, str]:
"""
        Runs a function and captures standard output and error.
:param fct: function to run
:return: result of *fct*, output, error
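
        Usage sketch (illustrative):

        ::

            res, out, err = self.capture(lambda: print("hello"))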
"""
sout = StringIO()
serr = StringIO()
with redirect_stdout(sout), redirect_stderr(serr):
try:
res = fct()
except Exception as e:
raise AssertionError(
f"function {fct} failed, stdout="
f"\n{sout.getvalue()}\n---\nstderr=\n{serr.getvalue()}"
) from e
return res, sout.getvalue(), serr.getvalue()
def tryCall(
self, fct: Callable, msg: Optional[str] = None, none_if: Optional[str] = None
) -> Optional[Any]:
"""
        Calls the function and catches any error.
:param fct: function to call
:param msg: error message to display if failing
:param none_if: returns None if this substring is found in the error message
:return: output of *fct*
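
        Usage sketch (illustrative, ``make_session`` stands for any callable):

        ::

            # make_session is a hypothetical callable that may fail
            sess = self.tryCall(lambda: make_session(), none_if="not installed")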
"""
try:
return fct()
except Exception as e:
if none_if is not None and none_if in str(e):
return None
if msg is None:
raise
raise AssertionError(msg) from e
def assert_onnx_disc(
self,
test_name: str,
proto: "onnx.ModelProto", # noqa: F821
model: "torch.nn.Module", # noqa: F821
inputs: Union[Tuple[Any], Dict[str, Any]],
verbose: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
copy_inputs: bool = True,
expected: Optional[Any] = None,
use_ort: bool = False,
ort_optimized_graph: bool = False,
ep: Optional[Union["torch.export.ExportedProgram", str]] = None, # noqa: F821
**kwargs,
):
"""
Checks for discrepancies.
Runs the onnx models, computes expected outputs, in that order.
        The inputs may be modified by this function if the torch model
        modifies them inplace.
:param test_name: test name, dumps the model if not empty
:param proto: onnx model
:param model: torch model
:param inputs: inputs
:param verbose: verbosity
:param atol: absolute tolerance
:param rtol: relative tolerance
:param expected: expected values
:param copy_inputs: to copy the inputs
:param use_ort: use :class:`onnxruntime.InferenceSession`
:param ort_optimized_graph: dumps the optimized onnxruntime graph
:param ep: exported program (or saved exported program)
:param kwargs: arguments sent to
:class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
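
        A minimal call sketch (illustrative):

        ::

            self.assert_onnx_disc("test_linear", proto, model, inputs, atol=1e-4)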
"""
from .helpers import string_type, string_diff, max_diff
from .helpers.rt_helper import make_feeds
from .helpers.ort_session import InferenceSessionForTorch
kws = dict(with_shape=True, with_min_max=verbose > 1)
vname = test_name or "assert_onnx_disc"
if test_name:
import onnx
name = f"{test_name}.onnx"
if verbose:
print(f"[{vname}] save the onnx model into {name!r}")
if isinstance(proto, str):
name = proto
proto = onnx.load(name)
elif not self.unit_test_going():
assert isinstance(
proto, onnx.ModelProto
), f"Unexpected type {type(proto)} for proto"
name = self.dump_onnx(name, proto)
if verbose and not self.unit_test_going():
print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
if verbose:
print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
if use_ort:
assert isinstance(
proto, onnx.ModelProto
), f"Unexpected type {type(proto)} for proto"
feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
import onnxruntime
options = onnxruntime.SessionOptions()
if ort_optimized_graph:
options.optimized_model_filepath = f"{name}.optort.onnx"
providers = kwargs.get("providers", ["CPUExecutionProvider"])
if verbose:
print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
sess = onnxruntime.InferenceSession(
proto.SerializeToString(), options, providers=providers
)
if verbose:
print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
got = sess.run(None, feeds)
else:
feeds = make_feeds(proto, inputs, copy=True)
if verbose:
print(f"[{vname}] create InferenceSessionForTorch")
sess = InferenceSessionForTorch(proto, **kwargs)
if verbose:
print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
got = sess.run(None, feeds)
if verbose:
print(f"[{vname}] compute expected values")
if expected is None:
if copy_inputs:
expected = (
model(*copy.deepcopy(inputs))
if isinstance(inputs, tuple)
else model(**copy.deepcopy(inputs))
)
else:
expected = model(*inputs) if isinstance(inputs, tuple) else model(**inputs)
if verbose:
print(f"[{vname}] expected {string_type(expected, **kws)}")
print(f"[{vname}] obtained {string_type(got, **kws)}")
if ep:
if isinstance(ep, str):
if verbose:
print(f"[{vname}] load exported program {ep!r}")
import torch
ep = torch.export.load(ep)
ep_inputs = copy.deepcopy(inputs) if copy_inputs else inputs
ep_model = ep.module() # type: ignore[union-attr]
ep_expected = (
ep_model(*copy.deepcopy(ep_inputs))
if isinstance(ep_inputs, tuple)
else ep_model(**copy.deepcopy(ep_inputs))
)
if verbose:
print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
ep_diff = max_diff(expected, ep_expected, hist=[0.1, 0.01])
if verbose:
print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
assert (
isinstance(ep_diff["abs"], float)
and isinstance(ep_diff["rel"], float)
and not numpy.isnan(ep_diff["abs"])
and ep_diff["abs"] <= atol
and not numpy.isnan(ep_diff["rel"])
and ep_diff["rel"] <= rtol
), (
f"discrepancies in {test_name!r} between the exported program "
f"and the exported model diff={string_diff(ep_diff)}"
)
ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
if verbose:
print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
diff = max_diff(expected, got, flatten=True, hist=[0.1, 0.01])
if verbose:
print(f"[{vname}] diff {string_diff(diff)}")
assert (
isinstance(diff["abs"], float)
and isinstance(diff["rel"], float)
and not numpy.isnan(diff["abs"])
and diff["abs"] <= atol
and not numpy.isnan(diff["rel"])
and diff["rel"] <= rtol
), (
f"discrepancies in {test_name!r} between the model and "
f"the onnx model diff={string_diff(diff)}"
)
def _debug(self):
"Tells if DEBUG=1 is set up."
return os.environ.get("DEBUG") in BOOLEAN_VALUES
def string_type(self, *args, **kwargs):
from .helpers import string_type
return string_type(*args, **kwargs)
def max_diff(self, *args, **kwargs):
from .helpers import max_diff
return max_diff(*args, **kwargs)
def use_dyn_not_str(self, *args, **kwargs):
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
        return use_dyn_not_str(*args, **kwargs)
def subloop(self, *args, verbose: int = 0):
"Loops over elements and calls :meth:`unittests.TestCase.subTest`."
if len(args) == 1:
for it in args[0]:
with self.subTest(case=it):
if verbose:
print(f"[subloop] it={it!r}")
yield it
else:
for it in itertools.product(*args):
with self.subTest(case=it):
if verbose:
print(f"[subloop] it={it!r}")
yield it