Source code for onnx_array_api.ext_test_case

import functools
import os
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from timeit import Timer
from typing import Any, Callable, Dict, List, Optional, Union

import numpy
from numpy.testing import assert_allclose


def is_azure() -> bool:
    """
    Tells if the job is running on Azure DevOps.

    The answer is derived from the environment variable
    ``AZURE_HTTP_USER_AGENT``: the function returns True as soon as it is
    set to anything but the literal string ``"undefined"``.
    """
    placeholder = "undefined"
    user_agent = os.environ.get("AZURE_HTTP_USER_AGENT", placeholder)
    return user_agent != placeholder


def is_windows() -> bool:
    """Tells if the interpreter reports ``win32`` as its platform."""
    platform_name = sys.platform
    return platform_name == "win32"


def is_apple() -> bool:
    """Tells if the interpreter reports ``darwin`` as its platform."""
    platform_name = sys.platform
    return platform_name == "darwin"


def skipif_ci_windows(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`Windows`.

    :param msg: reason, appended to the skip message
    :return: ``unittest.skip(...)`` when both :func:`is_windows` and
        :func:`is_azure` are true, otherwise a decorator returning
        the test unchanged
    """
    if is_windows() and is_azure():
        # The condition is about Windows: the message must say so
        # (it used to mention linux).
        msg = f"Test does not work on azure pipeline (Windows). {msg}"
        return unittest.skip(msg)
    return lambda x: x


def skipif_ci_apple(msg) -> Callable:
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS.

    :param msg: reason, appended to the skip message
    :return: ``unittest.skip(...)`` when both :func:`is_apple` and
        :func:`is_azure` are true, otherwise a decorator returning
        the test unchanged
    """
    on_apple_ci = is_apple() and is_azure()
    if not on_apple_ci:
        # Nothing to skip: hand the test back untouched.
        return lambda x: x
    reason = f"Test does not work on azure pipeline (Apple). {msg}"
    return unittest.skip(reason)


def ignore_warnings(warns: List[Warning]) -> Callable:
    """
    Catches warnings.

    :param warns: warnings to ignore, a single warning category
        or a list / tuple of categories
    :return: decorator for a test method; the decorated method runs
        with the given categories ignored and keeps its name and docstring
    """
    # The warnings module matches categories with ``issubclass``, which
    # accepts a class or a tuple of classes but raises TypeError on a list.
    if isinstance(warns, (list, set, frozenset)):
        categories = tuple(warns)
    else:
        categories = warns

    def wrapper(fct):
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            # catch_warnings restores the previous filters on exit.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", categories)
                return fct(self, *args, **kwargs)

        return call_f

    return wrapper
def matplotlib_test() -> Callable:
    """
    Decorator for every test checking matplotlib graphs.
    Cleans matplotlib after its completion.

    :return: decorator for a test method; after the test, whatever its
        outcome, the content of ``matplotlib.units.registry`` is restored
        and all pyplot figures are closed
    """

    def wrapper(fct):
        # ``import matplotlib`` alone does not guarantee that the submodule
        # ``matplotlib.units`` is loaded: import it explicitly.
        import matplotlib
        import matplotlib.units

        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            orig_units_registry = matplotlib.units.registry.copy()
            try:
                return fct(self, *args, **kwargs)
            finally:
                matplotlib.units.registry.clear()
                matplotlib.units.registry.update(orig_units_registry)
                # pyplot is looked up in sys.modules rather than imported
                # here: if the test never loaded it, there is no pyplot
                # figure to close, and the cleanup must not fail on a
                # missing attribute.
                pyplot = sys.modules.get("matplotlib.pyplot")
                if pyplot is not None:
                    pyplot.close("all")

        return call_f

    return wrapper
def example_path(path: str) -> str:
    """
    Fixes a path for the examples.
    Helps running the example within a unit test.

    :param path: path to look for
    :return: *path* itself if it exists, otherwise the normalized location
        ``<folder of this file>/../_doc/examples/<path>`` if that one exists
    :raises FileNotFoundError: if none of the two locations exists
    """
    if not os.path.exists(path):
        here = os.path.abspath(os.path.dirname(__file__))
        candidate = os.path.join(here, "..", "_doc", "examples", path)
        candidate = os.path.normpath(candidate)
        if not os.path.exists(candidate):
            raise FileNotFoundError(
                f"Unable to find path {path!r} or {candidate!r}."
            )
        return candidate
    return path
def measure_time(
    stmt: Union[Callable, str],
    context: Optional[Dict[str, Any]] = None,
    repeat: int = 10,
    number: int = 50,
    div_by_number: bool = True,
    max_time: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: callable, or string holding the statement to measure
    :param context: variable to know in a dictionary, used as globals
        when *stmt* is a string
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary with keys *average*, *deviation*, *min_exec*,
        *max_exec*, *repeat*, *number*, *ttime* and either *size*
        (if *context* contains ``"values"``) or *context_size*
    :raises TypeError: if *stmt* is neither a callable nor a string
    :raises ValueError: if *max_time* is set and *div_by_number* is False

    .. runpython::
        :showcode:

        from onnx_array_api.ext_test_case import measure_time
        from math import cos

        res = measure_time(lambda: cos(0.5))
        print(res)

    See `Timer.repeat <https://docs.python.org/3/library/timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    The function returns a duration corresponding to
    *number* times the execution of the main statement.

    .. versionchanged:: 0.4
        Parameter *max_time* was added.
    """
    if not callable(stmt) and not isinstance(stmt, str):
        raise TypeError(
            f"stmt is not callable or a string but is of type {type(stmt)!r}."
        )
    if context is None:
        context = {}

    import numpy

    if isinstance(stmt, str):
        tim = Timer(stmt, globals=context)
    else:
        tim = Timer(stmt)

    if max_time is not None:
        if not div_by_number:
            raise ValueError(
                "div_by_number must be set to True if max_time is defined."
            )
        i = 1
        total_time = 0
        results = []
        while True:
            # Two measures per round, with i and 2 * i executions.
            for j in (1, 2):
                n_runs = i * j
                time_taken = tim.timeit(n_runs)
                results.append((n_runs, time_taken))
                total_time += time_taken
                if total_time >= max_time:
                    break
            if total_time >= max_time:
                break
            if total_time > 0:
                # Scale the next round by the remaining time budget,
                # never shrinking the number of executions.
                ratio = max((max_time - total_time) / total_time, 1)
            else:
                # Nothing measurable yet: avoid a division by zero and
                # simply double the number of executions.
                ratio = 2
            i = int(i * ratio)

        res = numpy.array(results)
        tw = res[:, 0].sum()
        ttime = res[:, 1].sum()
        mean = ttime / tw
        ave = res[:, 1] / res[:, 0]
        # Deviation of the per-execution times weighted by the number
        # of executions of each measure.
        dev = (((ave - mean) ** 2 * res[:, 0]).sum() / tw) ** 0.5
        mes = dict(
            average=mean,
            deviation=dev,
            min_exec=numpy.min(ave),
            max_exec=numpy.max(ave),
            repeat=1,
            number=tw,
            ttime=ttime,
        )
    else:
        res = numpy.array(tim.repeat(repeat=repeat, number=number))
        if div_by_number:
            res /= number

        mean = numpy.mean(res)
        # E[x^2] - E[x]^2 may come out slightly negative because of
        # rounding errors; clip it so that the deviation is never NaN.
        variance = numpy.mean(res**2) - mean**2
        dev = numpy.sqrt(numpy.maximum(variance, 0.0))
        mes = dict(
            average=mean,
            deviation=dev,
            min_exec=numpy.min(res),
            max_exec=numpy.max(res),
            repeat=repeat,
            number=number,
            ttime=res.sum(),
        )

    if "values" in context:
        if hasattr(context["values"], "shape"):
            mes["size"] = context["values"].shape[0]
        else:
            mes["size"] = len(context["values"])
    else:
        mes["context_size"] = sys.getsizeof(context)
    return mes
class ExtTestCase(unittest.TestCase):
    """
    Extends :class:`unittest.TestCase` with additional assertions
    (arrays, emptiness, paths, attributes, prefixes) and helpers
    to capture the standard streams or to build paths.
    """

    # Items are unpacked as (name, line, warning) by tearDownClass,
    # which reports each of them through warnings.warn.
    _warns = []

    def assertEqualArray(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """
        Checks that two arrays share the same dtype, the same shape
        and the same values up to the given tolerances.

        :param expected: expected array
        :param value: array to check
        :param atol: absolute tolerance given to ``assert_allclose``
        :param rtol: relative tolerance given to ``assert_allclose``
        """
        self.assertEqual(expected.dtype, value.dtype)
        self.assertEqual(expected.shape, value.shape)
        assert_allclose(expected, value, atol=atol, rtol=rtol)

    def assertRaise(self, fct: Callable, exc_type: Exception):
        """
        Checks that calling *fct* raises an exception of type *exc_type*.
        Any other exception propagates unchanged.

        :param fct: function called with no argument
        :param exc_type: expected exception type
        :raises AssertionError: if *fct* returns without raising
        """
        try:
            fct()
        except exc_type:
            # The except clause already guarantees the type.
            return
        raise AssertionError("No exception was raised.")

    def assertEmpty(self, value: Any):
        """
        Checks that *value* is falsy (None, empty container, ...).

        :raises AssertionError: if *value* is truthy
        """
        if not value:
            return
        raise AssertionError(f"value is not empty: {value!r}.")

    def assertExists(self, name):
        """
        Checks that the file or folder *name* exists.

        :raises AssertionError: if the path does not exist
        """
        if not os.path.exists(name):
            raise AssertionError(f"File or folder {name!r} does not exists.")

    def assertHasAttr(self, cls: type, name: str):
        """
        Checks that *cls* has an attribute called *name*.

        :raises AssertionError: if the attribute is missing
        """
        if not hasattr(cls, name):
            raise AssertionError(f"Class {cls} has no attribute {name!r}.")

    def assertNotEmpty(self, value: Any):
        """
        Checks that *value* is not None and, if it is a list, a dict,
        a tuple or a set, that it holds at least one element.

        :raises AssertionError: if *value* is None or an empty container
        """
        if value is None:
            raise AssertionError(f"value is empty: {value!r}.")
        if isinstance(value, (list, dict, tuple, set)):
            # The test used to be inverted: it failed on non-empty
            # containers and accepted empty ones.
            if not value:
                raise AssertionError(f"value is empty: {value!r}.")

    def assertStartsWith(self, prefix: str, full: str):
        """
        Checks that string *full* starts with *prefix*.

        :raises AssertionError: if it does not
        """
        if not full.startswith(prefix):
            raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")

    @classmethod
    def tearDownClass(cls):
        """Emits one warning per item stored in ``cls._warns``."""
        for name, line, w in cls._warns:
            warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}")

    def capture(self, fct: Callable):
        """
        Runs a function and capture standard output and error.

        :param fct: function to run
        :return: result of *fct*, output, error
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout), redirect_stderr(serr):
            res = fct()
        return res, sout.getvalue(), serr.getvalue()

    def relative_path(self, filename: str, *names: str) -> str:
        """
        Returns a path relative to the folder *filename*
        is in. The function checks the path existence.

        :param filename: filename
        :param names: additional path pieces
        :return: new path
        :raises FileNotFoundError: if the new path does not exist
        """
        folder = os.path.abspath(os.path.dirname(filename))
        name = os.path.join(folder, *names)
        if not os.path.exists(name):
            raise FileNotFoundError(f"Path {name!r} does not exists.")
        return name