Source code for onnx_extended.ext_test_case

import os
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from timeit import Timer
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy
from numpy.testing import assert_allclose


def is_azure() -> bool:
    "Tells if the job is running on Azure DevOps."
    return os.environ.get("AZURE_HTTP_USER_AGENT", "undefined") != "undefined"


def is_windows() -> bool:
    return sys.platform == "win32"


def is_apple() -> bool:
    return sys.platform == "darwin"


def skipif_ci_windows(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
    """
    if is_windows() and is_azure():
        msg = f"Test does not work on azure pipeline (Windows). {msg}"
        return unittest.skip(msg)
    return lambda x: x


def skipif_ci_apple(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.
    """
    if is_apple() and is_azure():
        msg = f"Test does not work on azure pipeline (Apple). {msg}"
        return unittest.skip(msg)
    return lambda x: x


def skipif_unstable(msg) -> Callable:
    """
    Skips a unit test if the environment variable `SKIP_UNSTABLE` is set to 1.
    """
    value = os.environ.get("SKIP_UNSTABLE", "0")
    if value in ("1", 1):
        msg = f"Test is unstable. Disabling it. {msg}"
        return unittest.skip(msg)
    return lambda x: x


def unit_test_going():
    """
    Enables a flag telling the script is running while testing it.
    Avois unit tests to be very long.
    """
    going = int(os.environ.get("UNITTEST_GOING", 0))
    return going == 1


[docs] def ignore_warnings(warns: List[Warning]) -> Callable: """ Catches warnings. :param warns: warnings to ignore """ def wrapper(fct): assert warns is not None, f"warns cannot be None for '{fct}'." def call_f(self): with warnings.catch_warnings(): warnings.simplefilter("ignore", warns) return fct(self) return call_f return wrapper
[docs] def measure_time( stmt: Union[str, Callable], context: Optional[Dict[str, Any]] = None, repeat: int = 10, number: int = 50, warmup: int = 1, div_by_number: bool = True, max_time: Optional[float] = None, ) -> Dict[str, Union[str, int, float]]: """ Measures a statement and returns the results as a dictionary. :param stmt: string or callable :param context: variable to know in a dictionary :param repeat: average over *repeat* experiment :param number: number of executions in one row :param warmup: number of iteration to do before starting the real measurement :param div_by_number: divide by the number of executions :param max_time: execute the statement until the total goes beyond this time (approximatively), *repeat* is ignored, *div_by_number* must be set to True :return: dictionary .. runpython:: :showcode: from pprint import pprint from math import cos from onnx_extended.ext_test_case import measure_time res = measure_time(lambda: cos(0.5)) pprint(res) See `Timer.repeat <https://docs.python.org/3/library/ timeit.html?timeit.Timer.repeat>`_ for a better understanding of parameter *repeat* and *number*. The function returns a duration corresponding to *number* times the execution of the main statement. """ assert callable(stmt) or isinstance( stmt, str ), f"stmt is not callable or a string but is of type {type(stmt)!r}." if context is None: context = {} if isinstance(stmt, str): tim = Timer(stmt, globals=context) else: tim = Timer(stmt) if warmup > 0: warmup_time = tim.timeit(warmup) else: warmup_time = 0 if max_time is not None: assert ( div_by_number ), "div_by_number must be set to True of max_time is defined." i = 1 total_time = 0.0 results = [] while True: for j in (1, 2): number = i * j time_taken = tim.timeit(number) results.append((number, time_taken)) total_time += time_taken if total_time >= max_time: break if total_time >= max_time: break ratio = (max_time - total_time) / total_time ratio = max(ratio, 1) i = int(i * ratio) res = numpy.array(results) tw = res[:, 0].sum() ttime = res[:, 1].sum() mean = ttime / tw ave = res[:, 1] / res[:, 0] dev = (((ave - mean) ** 2 * res[:, 0]).sum() / tw) ** 0.5 mes = dict( average=mean, deviation=dev, min_exec=numpy.min(ave), max_exec=numpy.max(ave), repeat=1, number=tw, ttime=ttime, ) else: res = numpy.array(tim.repeat(repeat=repeat, number=number)) if div_by_number: res /= number mean = numpy.mean(res) dev = numpy.mean(res**2) dev = (dev - mean**2) ** 0.5 mes = dict( average=mean, deviation=dev, min_exec=numpy.min(res), max_exec=numpy.max(res), repeat=repeat, number=number, ttime=res.sum(), ) if "values" in context: if hasattr(context["values"], "shape"): mes["size"] = context["values"].shape[0] else: mes["size"] = len(context["values"]) else: mes["context_size"] = sys.getsizeof(context) mes["warmup_time"] = warmup_time return mes
[docs] class ExtTestCase(unittest.TestCase): _warns: List[Tuple[str, int, Warning]] = [] def assertExists(self, name): assert os.path.exists(name), f"File or folder {name!r} does not exists."
[docs] def assertIns(self, sub: Tuple[Any, ...], s: str): """ Checks that one of the substrings in sub is part of s. """ for t in sub: if t in s: return raise AssertionError(f"None of the substring in {sub} is part of {s!r}.")
def assertEqualArray( self, expected: numpy.ndarray, value: numpy.ndarray, atol: float = 0, rtol: float = 0, msg: Optional[str] = None, ): self.assertEqual(expected.dtype, value.dtype) self.assertEqual(expected.shape, value.shape) if msg: try: assert_allclose(expected, value, atol=atol, rtol=rtol) except AssertionError as e: raise AssertionError(msg) from e else: assert_allclose(expected, value, atol=atol, rtol=rtol)
[docs] def assertAlmostEqual( self, expected: numpy.ndarray, value: numpy.ndarray, atol: float = 0, rtol: float = 0, ): if not isinstance(expected, numpy.ndarray): expected = numpy.array(expected) if not isinstance(value, numpy.ndarray): value = numpy.array(value).astype(expected.dtype) self.assertEqualArray(expected, value, atol=atol, rtol=rtol)
[docs] def assertNotAlmostEqual( self, expected: numpy.ndarray, value: numpy.ndarray, atol: float = 0, rtol: float = 0, ): if not isinstance(expected, numpy.ndarray): expected = numpy.array(expected) if not isinstance(value, numpy.ndarray): value = numpy.array(value).astype(expected.dtype) try: self.assertEqualArray(expected, value, atol=atol, rtol=rtol) raise AssertionError("Arrays are equal.") except AssertionError: pass
def assertRaise(self, fct: Callable, exc_type: type[Exception]): try: fct() except exc_type as e: if not isinstance(e, exc_type): raise AssertionError(f"Unexpected exception {type(e)!r}.") return raise AssertionError("No exception was raised.") def assertEmpty(self, value: Any): if value is None: return if not value: return raise AssertionError(f"value is not empty: {value!r}.") def assertNotEmpty(self, value: Any): if value is None: raise AssertionError(f"value is empty: {value!r}.") if isinstance(value, (list, dict, tuple, set)): if not value: raise AssertionError(f"value is empty: {value!r}.") def assertStartsWith(self, prefix: str, full: str): if not full.startswith(prefix): raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")
[docs] @classmethod def tearDownClass(cls): for name, line, w in cls._warns: warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}")
[docs] def capture(self, fct: Callable) -> Tuple[Any, str, str]: """ Runs a function and capture standard output and error. :param fct: function to run :return: result of *fct*, output, error """ sout = StringIO() serr = StringIO() with redirect_stdout(sout): with redirect_stderr(serr): try: res = fct() except Exception as e: raise AssertionError( f"function {fct} failed, stdout=" f"\n{sout.getvalue()}\n---\nstderr=\n{serr.getvalue()}" ) from e return res, sout.getvalue(), serr.getvalue()
[docs] def tryCall( self, fct: Callable, msg: Optional[str] = None, none_if: Optional[str] = None ) -> Optional[Any]: """ Calls the function, catch any error. :param fct: function to call :param msg: error message to display if failing :param none_if: returns None if this substring is found in the error message :return: output of *fct* """ try: return fct() except Exception as e: if none_if is not None and none_if in str(e): return None if msg is None: raise e raise AssertionError(msg) from e
@classmethod def to_str(cls, onx: "onnx.ModelProto") -> str: # noqa: F821 if hasattr(onx, "SerializeToString"): from onnx_array_api.plotting.text_plot import onnx_simple_text_plot return onnx_simple_text_plot(onx) raise RuntimeError(f"Unable to print type {type(onx)}")