"""
The module contains the main class ``ExtTestCase`` which adds
specific functionalities to this project.
"""
import glob
import importlib
import logging
import os
import re
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from timeit import Timer
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
import numpy
from numpy.testing import assert_allclose
# Values accepted as "enabled" when reading a flag; ``requires_zoo`` checks
# the content of environment variable ``ZOO`` against this tuple.
BOOLEAN_VALUES = (1, "1", True, "True", "true", "TRUE")
[docs]
def is_azure() -> bool:
    """
    Tells if the job is running on Azure DevOps.

    The detection relies on environment variable ``AZURE_HTTP_USER_AGENT``:
    it must be defined and different from ``"undefined"``.
    """
    agent = os.environ.get("AZURE_HTTP_USER_AGENT")
    return agent is not None and agent != "undefined"
def is_windows() -> bool:
    """Tells if the interpreter runs on Windows (``sys.platform`` is ``"win32"``)."""
    platform_name = sys.platform
    return platform_name == "win32"
def is_apple() -> bool:
    """Tells if the interpreter runs on macOS (``sys.platform`` is ``"darwin"``)."""
    platform_name = sys.platform
    return platform_name == "darwin"
def is_linux() -> bool:
    """Tells if the interpreter runs on Linux (``sys.platform`` is ``"linux"``)."""
    platform_name = sys.platform
    return platform_name == "linux"
[docs]
def skipif_not_onnxrt(msg) -> Callable:
    """
    Skips a unit test unless environment variable ``UNITTEST_ONNXRT``
    enables it.

    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity

    The variable is read as an integer (non-zero enables the test).
    A value which is not an integer (empty string, ``"true"``...) no longer
    raises an exception while the test module is imported: ``"true"``
    (any case) enables the test, anything else disables it.
    """
    raw = os.environ.get("UNITTEST_ONNXRT", "0").strip()
    try:
        enabled = bool(int(raw))
    except ValueError:
        # The decorator is evaluated at import time, a bad value must not
        # break the collection of the whole test module.
        enabled = raw.lower() == "true"
    if enabled:
        return lambda x: x
    return unittest.skip(f"Set UNITTEST_ONNXRT=1 to run the unittest. {msg}")
[docs]
def skipif_ci_windows(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.

    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    on_ci_windows = is_windows() and is_azure()
    if not on_ci_windows:
        return lambda x: x
    return unittest.skip(f"Test does not work on azure pipeline (Windows). {msg}")
[docs]
def skipif_ci_linux(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Linux`.

    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    on_ci_linux = is_linux() and is_azure()
    if not on_ci_linux:
        return lambda x: x
    return unittest.skip(f"Takes too long (Linux). {msg}")
[docs]
def skipif_ci_apple(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS.

    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    on_ci_apple = is_apple() and is_azure()
    if not on_ci_apple:
        return lambda x: x
    return unittest.skip(f"Test does not work on azure pipeline (Apple). {msg}")
[docs]
def with_path_append(path_to_add: Union[str, List[str]]) -> Callable:
    """
    Decorator adding one or several paths to ``sys.path`` while a test runs.

    :param path_to_add: a path, a list of paths, or None to add nothing
    :return: a decorator for a test method taking only ``self``

    ``sys.path`` is restored to a copy taken before the test even if the
    test raises, and the value returned by the test is forwarded.
    """

    def wraps(f, path_to_add=path_to_add):
        def wrapped(self, path_to_add=path_to_add):
            saved = sys.path.copy()
            if path_to_add is not None:
                if isinstance(path_to_add, str):
                    path_to_add = [path_to_add]
                sys.path.extend(path_to_add)
            try:
                return f(self)
            finally:
                # A failing test must not leak its extra paths to the next ones.
                sys.path = saved

        return wrapped

    return wraps
[docs]
def unit_test_going():
    """
    Tells if the script is running as part of the unit tests.

    The flag is read from environment variable ``UNITTEST_GOING``
    (it must be equal to 1). It lets examples shorten their execution
    while being tested.
    """
    return int(os.environ.get("UNITTEST_GOING", 0)) == 1
[docs]
def ignore_warnings(warns: List[Warning]) -> Callable:
    """
    Decorator running a test method with some warnings ignored.

    :param warns: warning category or list/tuple of categories to ignore,
        the string ``"TracerWarning"`` is replaced by
        ``torch.jit.TracerWarning`` (torch is only imported in that case)
    :return: a decorator for a test method taking only ``self``
    :raises AssertionError: if *warns* is None
    """
    # The check must happen before *warns* is wrapped into a tuple,
    # otherwise it can never be true.
    if warns is None:
        raise AssertionError("warns cannot be None.")
    if not isinstance(warns, (tuple, list)):
        warns = (warns,)
    resolved = []
    for w in warns:
        if w == "TracerWarning":
            from torch.jit import TracerWarning

            resolved.append(TracerWarning)
        else:
            resolved.append(w)
    categories = tuple(resolved)

    def wrapper(fct):
        def call_f(self):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", categories)
                return fct(self)

        return call_f

    return wrapper
[docs]
def hide_stdout(f: Optional[Callable] = None) -> Callable:
    """
    Decorator capturing the standard output of a test method.
    ``UserWarning`` and ``DeprecationWarning`` are ignored while it runs.

    :param f: if not None, this function is called with the captured
        standard output (a string) once the test has returned
    :return: a decorator for a test method taking only ``self``

    An ``AssertionError`` whose message contains
    ``"torch is not recent enough, file"`` is converted into
    :class:`unittest.SkipTest`, any other exception is propagated
    (and *f* is not called).
    """

    def wrapper(fct):
        def call_f(self):
            st = StringIO()
            with redirect_stdout(st), warnings.catch_warnings():
                warnings.simplefilter("ignore", (UserWarning, DeprecationWarning))
                try:
                    result = fct(self)
                except AssertionError as e:
                    if "torch is not recent enough, file" in str(e):
                        raise unittest.SkipTest(str(e))  # noqa: B904
                    raise
            # The result is kept aside so that the callback is reached:
            # returning from inside the ``with`` block skipped it.
            if f is not None:
                f(st.getvalue())
            return result

        return call_f

    return wrapper
[docs]
def long_test(msg: str = "") -> Callable:
    """
    Skips a unit test unless environment variable ``LONGTEST`` enables it.

    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity

    The test is skipped when ``LONGTEST`` is undefined or equal to
    ``"0"``, ``"False"`` or ``"false"``.
    """
    # os.environ only holds strings, comparing with 0 or False is useless.
    if os.environ.get("LONGTEST", "0") in ("0", "False", "false"):
        return unittest.skip(f"Skipped (set LONGTEST=1 to run it). {msg}")
    return lambda x: x
[docs]
def measure_time(
    stmt: Union[str, Callable],
    context: Optional[Dict[str, Any]] = None,
    repeat: int = 10,
    number: int = 50,
    warmup: int = 1,
    div_by_number: bool = True,
    max_time: Optional[float] = None,
) -> Dict[str, Union[str, int, float]]:
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: string or callable
    :param context: variable to know in a dictionary
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param warmup: number of iteration to do before starting the
        real measurement
    :param div_by_number: divide by the number of executions
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary with keys *average*, *deviation*, *min_exec*,
        *max_exec*, *repeat*, *number*, *ttime*, *warmup_time* and
        either *size* (if *context* contains ``values``) or *context_size*
    :raises TypeError: if *stmt* is neither a string nor a callable
    :raises ValueError: if *max_time* is set and *div_by_number* is False

    .. runpython::
        :showcode:

        from pprint import pprint
        from math import cos
        from experimental_experiment.ext_test_case import measure_time

        res = measure_time(lambda: cos(0.5))
        pprint(res)

    See `Timer.repeat <https://docs.python.org/3/library/
    timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.
    """
    if not callable(stmt) and not isinstance(stmt, str):
        raise TypeError(f"stmt is not callable or a string but is of type {type(stmt)!r}.")
    if context is None:
        context = {}

    # The context is only needed to evaluate a statement given as a string.
    tim = Timer(stmt, globals=context) if isinstance(stmt, str) else Timer(stmt)
    warmup_time = tim.timeit(warmup) if warmup > 0 else 0

    if max_time is not None:
        if not div_by_number:
            raise ValueError("div_by_number must be set to True if max_time is defined.")
        i = 1
        total_time = 0.0
        results = []
        while True:
            for j in (1, 2):
                number = i * j
                time_taken = tim.timeit(number)
                results.append((number, time_taken))
                total_time += time_taken
                if total_time >= max_time:
                    break
            if total_time >= max_time:
                break
            if total_time > 0:
                # Scales the number of executions to fill the remaining time.
                ratio = max((max_time - total_time) / total_time, 1)
            else:
                # The clock did not move: just try more executions.
                ratio = 2
            i = int(i * ratio)

        res = numpy.array(results)
        tw = res[:, 0].sum()
        ttime = res[:, 1].sum()
        mean = ttime / tw
        ave = res[:, 1] / res[:, 0]
        # Deviation weighted by the number of executions of every batch.
        dev = (((ave - mean) ** 2 * res[:, 0]).sum() / tw) ** 0.5
        mes = dict(
            average=mean,
            deviation=dev,
            min_exec=numpy.min(ave),
            max_exec=numpy.max(ave),
            repeat=1,
            number=tw,
            ttime=ttime,
        )
    else:
        res = numpy.array(tim.repeat(repeat=repeat, number=number))
        if div_by_number:
            res /= number

        mean = numpy.mean(res)
        # E[x^2] - E[x]^2 can be slightly negative due to rounding errors,
        # it is clamped to avoid returning NaN.
        variance = max(numpy.mean(res**2) - mean**2, 0.0)
        dev = variance**0.5
        mes = dict(
            average=mean,
            deviation=dev,
            min_exec=numpy.min(res),
            max_exec=numpy.max(res),
            repeat=repeat,
            number=number,
            ttime=res.sum(),
        )

    if "values" in context:
        if hasattr(context["values"], "shape"):
            mes["size"] = context["values"].shape[0]
        else:
            mes["size"] = len(context["values"])
    else:
        mes["context_size"] = sys.getsizeof(context)
    mes["warmup_time"] = warmup_time
    return mes
[docs]
class ExtTestCase(unittest.TestCase):
    """
    Inherits from :class:`unittest.TestCase` and adds specific comparison
    functions and other helpers.
    """

    # Entries (name, line, warning) re-emitted by tearDownClass.
    # The list is a class attribute, it is shared by all subclasses
    # which do not redefine it.
    _warns: List[Tuple[str, int, Warning]] = []

    @classmethod
    def setUpClass(cls):
        """Raises the level of a verbose onnxscript logger to ``ERROR``."""
        logger = logging.getLogger("onnxscript.optimizer.constant_folding")
        logger.setLevel(logging.ERROR)
        super().setUpClass()

    def print_model(self, model: "ModelProto"):  # noqa: F821
        "Prints a ModelProto"
        from experimental_experiment.helpers import pretty_onnx

        print(pretty_onnx(model))

    def print_onnx(self, model: "ModelProto"):  # noqa: F821
        "Prints a ModelProto"
        from experimental_experiment.helpers import pretty_onnx

        print(pretty_onnx(model))

    def get_dump_file(self, name: str, folder: Optional[str] = None) -> str:
        """
        Returns a filename to dump a model.

        :param name: file name
        :param folder: destination folder, ``dump_test`` if None,
            it is created if it does not exist
        :return: *name* joined to *folder*
        """
        if folder is None:
            folder = "dump_test"
        if folder and not os.path.exists(folder):
            os.mkdir(folder)
        return os.path.join(folder, name)

    def dump_onnx(
        self,
        name: str,
        proto: Any,
        folder: Optional[str] = None,
    ) -> str:
        """
        Dumps an onnx file.

        :param name: file name
        :param proto: any object implementing ``SerializeToString``
        :param folder: see :meth:`get_dump_file`
        :return: the name of the written file
        """
        fullname = self.get_dump_file(name, folder=folder)
        with open(fullname, "wb") as f:
            f.write(proto.SerializeToString())
        return fullname

    def assertExists(self, name):
        """Checks the existence of a file or a folder."""
        if not os.path.exists(name):
            raise AssertionError(f"File or folder {name!r} does not exists.")

    def assertGreaterOrEqual(self, a, b, msg=None):
        """Fails if ``a < b``."""
        if a < b:
            # The exception must be raised, returning it made the check a no-op.
            raise AssertionError(f"{a} < {b}, a not greater or equal than b\n{msg or ''}")

    def assertInOr(self, tofind: Tuple[str, ...], text: str, msg: str = ""):
        """Fails if none of the strings in *tofind* is part of *text*."""
        for tof in tofind:
            if tof in text:
                return
        raise AssertionError(
            msg or f"Unable to find one string in the list {tofind!r} in\n--\n{text}"
        )

    def assertIn(self, tofind: str, text: str, msg: str = ""):
        """Fails if ``tofind in text`` is false."""
        if tofind in text:
            return
        raise AssertionError(
            msg or f"Unable to find the list of strings {tofind!r} in\n--\n{text}"
        )

    def assertEqualArrays(
        self,
        expected: Sequence[numpy.ndarray],
        value: Sequence[numpy.ndarray],
        atol: float = 0,
        rtol: float = 0,
        msg: Optional[str] = None,
    ):
        """
        Compares two sequences of arrays with :meth:`assertEqualArray`,
        the sequences must have the same length.
        """
        self.assertEqual(len(expected), len(value), msg=msg or "")
        for a, b in zip(expected, value):
            self.assertEqualArray(a, b, atol=atol, rtol=rtol, msg=msg)

    def _assertSameDtypeAndShape(self, expected: Any, value: Any, msg: Optional[str]):
        """
        Checks that *expected* and *value* have the same ``dtype`` and
        ``shape``, the failure message is *msg* if it is specified.
        """
        for attr in ("dtype", "shape"):
            try:
                self.assertEqual(getattr(expected, attr), getattr(value, attr))
            except AssertionError as e:
                if msg:
                    raise AssertionError(msg) from e
                raise

    def assertEqualArray(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
        msg: Optional[str] = None,
    ):
        """
        Compares two arrays: same dtype, same shape, same values up to
        the given tolerances.

        :param expected: expected array
        :param value: computed array
        :param atol: absolute tolerance
        :param rtol: relative tolerance
        :param msg: text added to the error message

        If both objects have an attribute ``detach`` they are compared with
        ``torch.testing.assert_close``. Otherwise, any object with
        ``detach`` is converted into a numpy array
        (``.detach().cpu().numpy()``) and the comparison is done with
        :func:`numpy.testing.assert_allclose`.
        """
        if hasattr(expected, "detach") and hasattr(value, "detach"):
            self._assertSameDtypeAndShape(expected, value, msg)

            import torch

            try:
                torch.testing.assert_close(value, expected, atol=atol, rtol=rtol)
            except AssertionError as e:
                expected_max = torch.abs(expected).max()
                expected_value = torch.abs(value).max()
                rows = [
                    f"{msg}\n{e}" if msg else str(e),
                    f"expected max value={expected_max}",
                    f"expected computed value={expected_value}",
                ]
                raise AssertionError("\n".join(rows))  # noqa: B904
            return

        if hasattr(expected, "detach"):
            expected = expected.detach().cpu().numpy()
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        self._assertSameDtypeAndShape(expected, value, msg)

        try:
            assert_allclose(desired=expected, actual=value, atol=atol, rtol=rtol)
        except AssertionError as e:
            expected_max = numpy.abs(expected).max()
            expected_value = numpy.abs(value).max()
            rows = [
                f"{msg}\n{e}" if msg else str(e),
                f"expected max value={expected_max}",
                f"expected computed value={expected_value}",
            ]
            raise AssertionError("\n".join(rows))  # noqa: B904

    def assertEqualTrue(self, value: Any, msg: str = ""):
        """Fails unless *value* is the object ``True``."""
        if value is True:
            return
        raise AssertionError(msg or f"value is not True: {value!r}")

    def assertEqual(self, expected: Any, value: Any, msg: str = ""):
        """Overwrites the error message to get a more explicit message about what is what."""
        if msg:
            super().assertEqual(expected, value, msg)
            return
        try:
            super().assertEqual(expected, value)
        except AssertionError as e:
            raise AssertionError(  # noqa: B904
                f"expected is {expected!r}, value is {value!r}\n{e}"
            )

    def assertEqualAny(self, expected: Any, value: Any, msg: str = ""):
        """
        Recursively compares nested structures made of tuples, lists,
        dictionaries, integers, floats, strings, objects with a ``shape``
        and instances of a class named ``DynamicCache`` (compared through
        attributes ``key_cache`` and ``value_cache``).

        :raises AssertionError: if the values differ or if the type of
            *expected* is not supported
        """
        if isinstance(expected, (tuple, list, dict)):
            self.assertIsInstance(value, type(expected), msg=msg)
            self.assertEqual(len(expected), len(value), msg=msg)
            if isinstance(expected, dict):
                for k in expected:
                    self.assertIn(k, value, msg=msg)
                    self.assertEqualAny(expected[k], value[k], msg=msg)
            else:
                for e, g in zip(expected, value):
                    self.assertEqualAny(e, g, msg=msg)
        elif expected.__class__.__name__ == "DynamicCache":
            atts = {"key_cache", "value_cache"}
            self.assertEqualAny(
                {k: expected.__dict__.get(k, None) for k in atts},
                {k: value.__dict__.get(k, None) for k in atts},
                msg=msg,
            )
        elif isinstance(expected, (int, float, str)):
            self.assertEqual(expected, value, msg=msg)
        elif hasattr(expected, "shape"):
            self.assertEqual(type(expected), type(value), msg=msg)
            self.assertEqualArray(expected, value, msg=msg)
        else:
            raise AssertionError(
                f"Comparison not implemented for types {type(expected)} and {type(value)}"
            )

    def assertAlmostEqual(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """
        Converts both values into numpy arrays if needed (*value* is cast
        to the dtype of *expected*) and calls :meth:`assertEqualArray`.
        """
        if not isinstance(expected, numpy.ndarray):
            expected = numpy.array(expected)
        if not isinstance(value, numpy.ndarray):
            value = numpy.array(value).astype(expected.dtype)
        self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

    def check_ort(self, onx: "onnx.ModelProto") -> "InferenceSession":  # noqa: F821
        """
        Loads the model with onnxruntime on CPU.

        :param onx: model, it must implement ``SerializeToString``
        :return: the created ``InferenceSession``
        """
        from onnxruntime import InferenceSession

        return InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])

    def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
        """
        Checks that *fct* raises an exception of type *exc_type*.

        :param fct: function called with no argument
        :param exc_type: expected exception type
        :param msg: if specified, it must be part of the exception message

        An exception of another type is propagated unchanged.
        """
        try:
            fct()
        except exc_type as e:
            if msg is not None and msg not in str(e):
                raise AssertionError(f"Unexpected exception message {e!r}.")  # noqa: B904
            return
        raise AssertionError("No exception was raised.")  # noqa: B904

    def assertEmpty(self, value: Any):
        """Fails unless *value* is None or evaluates to False."""
        if value is None:
            return
        if not value:
            return
        raise AssertionError(f"value is not empty: {value!r}.")

    def assertNotEmpty(self, value: Any):
        """Fails if *value* is None or an empty list, dict, tuple or set."""
        if value is None:
            raise AssertionError(f"value is empty: {value!r}.")
        if isinstance(value, (list, dict, tuple, set)):
            if not value:
                raise AssertionError(f"value is empty: {value!r}.")

    def assertStartsWith(self, prefix: str, full: str):
        """Fails unless *full* starts with *prefix*."""
        if not full.startswith(prefix):
            raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")

    @classmethod
    def tearDownClass(cls):
        """Emits again every warning stored in ``_warns``."""
        for name, line, w in cls._warns:
            warnings.warn(f"\n{name}:{line}: {type(w)}\n  {w!s}", stacklevel=2)
        super().tearDownClass()

    def capture(self, fct: Callable):
        """
        Runs a function and capture standard output and error.

        :param fct: function to run
        :return: result of *fct*, output, error
        :raises AssertionError: if *fct* raises an exception, the message
            contains what was captured so far
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout), redirect_stderr(serr):
            try:
                res = fct()
            except Exception as e:
                raise AssertionError(
                    f"function {fct} failed, stdout="
                    f"\n{sout.getvalue()}\n---\nstderr=\n{serr.getvalue()}"
                ) from e
        return res, sout.getvalue(), serr.getvalue()

    def tryCall(
        self, fct: Callable, msg: Optional[str] = None, none_if: Optional[str] = None
    ) -> Optional[Any]:
        """
        Calls the function, catch any error.

        :param fct: function to call
        :param msg: error message to display if failing
        :param none_if: returns None if this substring is found in the error message
        :return: output of *fct*
        """
        try:
            return fct()
        except Exception as e:
            if none_if is not None and none_if in str(e):
                return None
            if msg is None:
                raise
            raise AssertionError(msg) from e
[docs]
def dump_dort_onnx(fn):
    """
    Decorator setting environment variable ``ONNXRT_DUMP_PATH`` to
    ``tests_dump/<name of the test>_`` while the decorated test runs.

    :param fn: test method taking only ``self``
    :return: the wrapped method

    Folder ``tests_dump`` is created when the decorator is applied.
    After the test, even if it fails, the variable is set back to its
    previous value or to an empty string if it was not defined.
    """
    prefix = fn.__name__
    folder = "tests_dump"
    if not os.path.exists(folder):
        os.mkdir(folder)

    def wrapped(self):
        previous = os.environ.get("ONNXRT_DUMP_PATH", None)
        os.environ["ONNXRT_DUMP_PATH"] = os.path.join(folder, f"{prefix}_")
        try:
            return fn(self)
        finally:
            # A failing test must not leave its dump path to the next ones.
            os.environ["ONNXRT_DUMP_PATH"] = previous or ""

    return wrapped
[docs]
def has_cuda() -> bool:
    """Tells if CUDA is available, it returns ``torch.cuda.is_available()``."""
    import torch

    available = torch.cuda.is_available()
    return available
[docs]
def requires_cuda(msg: str = "", version: str = "", memory: int = 0):
    """
    Skips a test if cuda is not available.

    :param msg: to overwrite the message
    :param version: minimum version
    :param memory: minimum number of Gb to run the test
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import torch

    if not torch.cuda.is_available():
        return unittest.skip(msg or "only runs on CUDA but torch does not have it")
    if version:
        # The module is ``packaging.version``, ``packaging.versions``
        # does not exist and made this branch fail with an ImportError.
        import packaging.version as pv

        if pv.Version(torch.version.cuda) < pv.Version(version):
            return unittest.skip(msg or f"CUDA older than {version}")
    if memory:
        # Total memory of device 0 in Gb.
        m = torch.cuda.get_device_properties(0).total_memory / 2**30
        if m < memory:
            return unittest.skip(msg or f"available memory is not enough {m} < {memory} (Gb)")
    return lambda x: x
[docs]
def requires_zoo(msg: str = "") -> Callable:
    """
    Skips a unit test if environment variable ``ZOO`` is not one of
    the values listed in ``BOOLEAN_VALUES``.

    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    enabled = os.environ.get("ZOO", "0") in BOOLEAN_VALUES
    if enabled:
        return lambda x: x
    return unittest.skip(f"ZOO not set up or != 1. {msg}")
[docs]
def has_executorch(version: str = "", msg: str = "") -> bool:
    """
    Tells if :epkg:`ExecuTorch` is installed.

    :param version: if specified, the installed version (major.minor)
        must be at least this one
    :param msg: unused
    :return: True if the package is installed and recent enough

    The function returns False, instead of raising an exception,
    when a version is requested and the package is missing.
    A package exposing no ``__version__`` is considered recent enough.
    """
    # importlib.util is a submodule, ``import importlib`` alone does not
    # guarantee it is loaded.
    import importlib.util

    if importlib.util.find_spec("executorch") is None:
        return False
    if not version:
        return True

    import packaging.version as pv
    import executorch

    installed = getattr(executorch, "__version__", None)
    if installed is None:
        return True
    # Same convention as has_torch: True when installed >= requested.
    return pv.Version(".".join(installed.split(".")[:2])) >= pv.Version(version)
[docs]
def requires_sklearn(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`scikit-learn` is not recent enough.

    :param version: minimum version, compared to major.minor of the
        installed version
    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv
    import sklearn

    major_minor = ".".join(sklearn.__version__.split(".")[:2])
    if pv.Version(major_minor) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"scikit-learn version {sklearn.__version__} < {version}: {msg}")
[docs]
def has_torch(version: str) -> bool:
    """
    Returns True if the installed version of torch (major.minor)
    is greater than or equal to *version*.
    """
    import packaging.version as pv
    import torch

    installed = pv.Version(".".join(torch.__version__.split(".")[:2]))
    required = pv.Version(version)
    return installed >= required
[docs]
def requires_torch(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`pytorch` is not recent enough.

    :param version: minimum version, compared to major.minor of the
        installed version
    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv
    import torch

    major_minor = ".".join(torch.__version__.split(".")[:2])
    if pv.Version(major_minor) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"torch version {torch.__version__} < {version}: {msg}")
[docs]
def requires_executorch(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`executorch` is not installed
    or not recent enough.

    :param version: minimum version, compared to major.minor of the
        installed version (ignored if the package has no ``__version__``)
    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    if not has_executorch():
        return unittest.skip(f"executorch is not installed: {msg}")

    import packaging.version as pv
    import executorch

    if hasattr(executorch, "__version__") and pv.Version(
        ".".join(executorch.__version__.split(".")[:2])
    ) < pv.Version(version):
        # The reason names the right package (it used to say "torch").
        return unittest.skip(f"executorch version {executorch.__version__} < {version}: {msg}")
    return lambda x: x
[docs]
def requires_monai(version: str = "", msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`monai` is not installed or not recent enough.

    :param version: minimum version (optional), compared to major.minor
        of the installed version
    :param msg: skip reason if the package is missing, appended to the
        reason if the package is too old
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv

    try:
        import monai
    except ImportError:
        return unittest.skip(msg or "monai is not installed")

    if not version:
        return lambda x: x
    major_minor = ".".join(monai.__version__.split(".")[:2])
    if pv.Version(major_minor) < pv.Version(version):
        return unittest.skip(f"monai version {monai.__version__} < {version}: {msg}")
    return lambda x: x
[docs]
def requires_vocos(version: str = "", msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`vocos` is not installed or not recent enough.

    :param version: minimum version (optional), compared to major.minor
        of the installed version
    :param msg: skip reason if the package is missing, appended to the
        reason if the package is too old
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv

    try:
        import vocos
    except ImportError:
        return unittest.skip(msg or "vocos not installed")

    if not version:
        return lambda x: x
    major_minor = ".".join(vocos.__version__.split(".")[:2])
    if pv.Version(major_minor) < pv.Version(version):
        return unittest.skip(f"vocos version {vocos.__version__} < {version}: {msg}")
    return lambda x: x
[docs]
def requires_pyinstrument(version: str = "", msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`pyinstrument` is not installed
    or not recent enough.

    :param version: minimum version (optional), compared to major.minor
        of the installed version
    :param msg: skip reason if the package is missing, appended to the
        reason if the package is too old
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv

    try:
        import pyinstrument
    except ImportError:
        return unittest.skip(msg or "pyinstrument is not installed")

    if version and pv.Version(".".join(pyinstrument.__version__.split(".")[:2])) < pv.Version(
        version
    ):
        # The reason names the right package (it used to say "torch").
        return unittest.skip(
            f"pyinstrument version {pyinstrument.__version__} < {version}: {msg}"
        )
    return lambda x: x
[docs]
def requires_numpy(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`numpy` is not recent enough.

    :param version: minimum version, compared to major.minor of the
        installed version
    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv
    import numpy

    major_minor = ".".join(numpy.__version__.split(".")[:2])
    if pv.Version(major_minor) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"numpy version {numpy.__version__} < {version}: {msg}")
[docs]
def require_diffusers(
    version: str, msg: str = "", or_older_than: Optional[str] = None
) -> Callable:
    """
    Skips a unit test if :epkg:`diffusers` is not installed, not recent
    enough or too recent.

    :param version: minimum version, compared to major.minor of the
        installed version
    :param msg: text appended to the skip reason
    :param or_older_than: if specified, the test is also skipped when
        the installed version (major.minor) is strictly greater than
        this one
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv

    try:
        import diffusers
    except ImportError:
        return unittest.skip(f"diffusers not installed {msg}")

    v = pv.Version(".".join(diffusers.__version__.split(".")[:2]))
    if v < pv.Version(version):
        return unittest.skip(f"diffusers version {diffusers.__version__} < {version} {msg}")
    if or_older_than and v > pv.Version(or_older_than):
        # The reason only mentions the bound which is actually violated.
        return unittest.skip(
            f"diffusers version {diffusers.__version__} > {or_older_than} {msg}"
        )
    return lambda x: x
[docs]
def requires_onnxscript(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`onnxscript` is not recent enough.

    :param version: minimum version, compared to major.minor of the
        installed version
    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity

    A package with no attribute ``__version__`` is considered as a
    development version and the test is not skipped.
    """
    import packaging.version as pv
    import onnxscript

    installed = getattr(onnxscript, "__version__", None)
    if not hasattr(onnxscript, "__version__"):
        # development version
        return lambda x: x
    major_minor = ".".join(installed.split(".")[:2])
    if pv.Version(major_minor) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"onnxscript version {onnxscript.__version__} < {version}: {msg}")
[docs]
def requires_onnxruntime(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`onnxruntime` is not recent enough.

    :param version: minimum version, compared to major.minor of the
        installed version
    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv
    import onnxruntime

    major_minor = ".".join(onnxruntime.__version__.split(".")[:2])
    if pv.Version(major_minor) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"onnxruntime version {onnxruntime.__version__} < {version}: {msg}")
[docs]
def has_onnxruntime_training(push_back_batch: bool = False):
    """
    Tells if onnxruntime_training is installed.

    :param push_back_batch: if True, ``OrtValueVector`` must also be
        importable and expose ``push_back_batch``
    :return: boolean
    """
    try:
        from onnxruntime import training
    except ImportError:
        # onnxruntime not training
        return False
    if training is None:
        return False
    if not push_back_batch:
        return True

    try:
        from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
    except ImportError:
        return False
    return hasattr(OrtValueVector, "push_back_batch")
[docs]
def requires_onnxruntime_training(
    push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
) -> Callable:
    """
    Skips a unit test if :epkg:`onnxruntime` is not onnxruntime_training.

    :param push_back_batch: also requires ``OrtValueVector`` to be
        importable and to expose ``push_back_batch``
    :param ortmodule: also requires ``onnxruntime.training.ortmodule``
        to be importable
    :param msg: replaces the default skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    try:
        from onnxruntime import training
    except ImportError:
        # onnxruntime not training
        training = None
    if training is None:
        return unittest.skip(msg or "onnxruntime_training is not installed")

    if push_back_batch:
        try:
            from onnxruntime.capi.onnxruntime_pybind11_state import OrtValueVector
        except ImportError:
            OrtValueVector = None
        # A failed import and a missing method lead to the same reason.
        if not hasattr(OrtValueVector, "push_back_batch"):
            return unittest.skip(msg or "OrtValue has no method push_back_batch")

    if ortmodule:
        try:
            import onnxruntime.training.ortmodule  # noqa: F401
        except (AttributeError, ImportError):
            return unittest.skip(msg or "ortmodule is missing in onnxruntime-training")
    return lambda x: x
[docs]
def requires_onnx(version: str, msg: str = "") -> Callable:
    """
    Skips a unit test if :epkg:`onnx` is not recent enough.

    :param version: minimum version, compared to major.minor of the
        installed version
    :param msg: text appended to the skip reason
    :return: a decorator, either :func:`unittest.skip` or the identity
    """
    import packaging.version as pv
    import onnx

    major_minor = ".".join(onnx.__version__.split(".")[:2])
    if pv.Version(major_minor) >= pv.Version(version):
        return lambda x: x
    return unittest.skip(f"onnx version {onnx.__version__} < {version}: {msg}")
[docs]
def statistics_on_file(filename: str) -> Dict[str, Union[int, float, str]]:
    """
    Computes statistics on a file.

    :param filename: existing file
    :return: ``{"size": ...}`` (size in bytes) if the extension is not one of
        ``.py``, ``.rst``, ``.md``, ``.txt``, otherwise a dictionary
        with *lines* (number of lines holding at least one letter or
        digit among ``a-z0-9`` after lowercasing), *chars* (number of
        characters left after stripping every line and removing spaces)
        and *ext*

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.ext_test_case import statistics_on_file, __file__

        pprint.pprint(statistics_on_file(__file__))
    """
    assert os.path.exists(filename), f"File {filename!r} does not exists."

    ext = os.path.splitext(filename)[-1]
    if ext not in {".py", ".rst", ".md", ".txt"}:
        return {"size": os.stat(filename).st_size}

    alnum = set("abcdefghijklmnopqrstuvwxyz0123456789")
    n_line = 0
    n_ch = 0
    with open(filename, "r", encoding="utf-8") as f:
        for raw in f:
            stripped = raw.strip("\n\r\t ")
            if not stripped:
                continue
            n_ch += len(stripped.replace(" ", ""))
            # It avoids counting lines with only a bracket, a comma.
            if not alnum.isdisjoint(stripped.lower()):
                n_line += 1

    # No extra statistics on python syntax yet: same output for every text file.
    return dict(lines=n_line, chars=n_ch, ext=ext)
[docs]
def statistics_on_folder(
    folder: Union[str, List[str]],
    pattern: str = ".*[.]((py|rst))$",
    aggregation: int = 0,
) -> List[Dict[str, Union[int, float, str]]]:
    """
    Computes statistics on files in a folder.

    :param folder: folder or folders to investigate
    :param pattern: file pattern, a regular expression matched against
        the path of every file relative to *folder*
    :param aggregation: show the first subfolders, if positive, every
        row gets a key ``dir`` built from the first *aggregation* levels
        of the relative path
    :return: list of dictionaries, the output of
        :func:`statistics_on_file` completed with ``name`` and ``dir``

    If *folder* is a list, every folder is processed with
    ``aggregation - 1`` and its last path component is added in front
    of ``dir``.

    .. runpython::
        :showcode:

        import os
        import pprint
        from experimental_experiment.ext_test_case import statistics_on_folder, __file__

        pprint.pprint(statistics_on_folder(os.path.dirname(__file__)))
    """
    if isinstance(folder, list):
        rows = []
        for fo in folder:
            last = fo.replace("\\", "/").split("/")[-1]
            r = statistics_on_folder(fo, pattern=pattern, aggregation=max(aggregation - 1, 0))
            if aggregation > 0:
                for line in r:
                    # With aggregation == 1, the recursive call runs with
                    # aggregation == 0 and produces rows without "dir".
                    line["dir"] = os.path.join(last, line.get("dir", ""))
            rows.extend(r)
        return rows

    rows = []
    reg = re.compile(pattern)
    for name in glob.glob("**/*", root_dir=folder, recursive=True):
        if not reg.match(name):
            continue
        if os.path.isdir(os.path.join(folder, name)):
            continue
        stat = statistics_on_file(os.path.join(folder, name))
        stat["name"] = name
        if aggregation > 0:
            parts = os.path.dirname(name).replace("\\", "/").split("/")
            stat["dir"] = "/".join(parts[:aggregation])
        rows.append(stat)
    return rows