onnx_diagnostic.ext_test_case
The module contains the main class ExtTestCase, which adds functionalities specific to this project.
- class onnx_diagnostic.ext_test_case.ExtTestCase(methodName='runTest')
Inherits from unittest.TestCase and adds specific comparison functions and other helpers.
- assertAlmostEqual(expected: ndarray, value: ndarray, atol: float = 0, rtol: float = 0)
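As the name says: asserts that the arrays expected and value are equal up to the absolute tolerance atol and the relative tolerance rtol.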
- assertEqual(expected: Any, value: Any, msg: str = '')
Overrides unittest.TestCase.assertEqual to produce a more explicit error message that tells which value was expected and which was obtained.
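- assertEqualArray(expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str | None = None)
As the name says: asserts that two arrays are equal up to the tolerances atol and rtol, with an optional custom message msg.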
- assertEqualArrays(expected: Sequence[ndarray], value: Sequence[ndarray], atol: float = 0, rtol: float = 0, msg: str | None = None)
Same check as assertEqualArray, applied to each pair of arrays taken from the two sequences.
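The tolerance check behind these assertions can be sketched in pure Python. This is a minimal sketch, not the library code. It assumes a rule in the style of numpy.testing.assert_allclose, |expected - value| <= atol + rtol * |expected| element by element, applied after a shape (here, length) check. The helper name almost_equal_lists is hypothetical, and the real methods work on arrays rather than flat lists.

```python
from typing import Sequence


def almost_equal_lists(
    expected: Sequence[float],
    value: Sequence[float],
    atol: float = 0.0,
    rtol: float = 0.0,
) -> bool:
    """Hypothetical helper: True if both sequences have the same length and
    every pair satisfies |e - v| <= atol + rtol * |e| (assumed rule)."""
    if len(expected) != len(value):
        # a shape mismatch fails before any value is compared
        return False
    return all(
        abs(e - v) <= atol + rtol * abs(e) for e, v in zip(expected, value)
    )


# With the default atol=0 and rtol=0 the comparison is exact.
print(almost_equal_lists([1.0, 2.0], [1.0, 2.0]))                # True
print(almost_equal_lists([1.0, 2.0], [1.0, 2.0001]))             # False
print(almost_equal_lists([1.0, 2.0], [1.0, 2.0001], atol=1e-3))  # True
```

Because both tolerances default to 0, a test that calls these assertions without atol or rtol asks for exact equality.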
- assertIn(tofind: str, text: str, msg: str = '')
Just like self.assertTrue(tofind in text), but with a nicer default message.
- capture(fct: Callable)
Runs a function and captures its standard output and standard error.
- Parameters:
fct – function to run
- Returns:
result of fct, captured standard output, captured standard error
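The behaviour of capture can be approximated with the standard library alone. This sketch is an assumption about how such a helper may be written. The hypothetical function capture_output redirects both streams with contextlib.redirect_stdout and contextlib.redirect_stderr while fct runs.

```python
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Callable, Tuple


def capture_output(fct: Callable[[], Any]) -> Tuple[Any, str, str]:
    """Hypothetical stand-in for ExtTestCase.capture: runs fct and returns
    (result, captured stdout, captured stderr)."""
    out, err = StringIO(), StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        res = fct()
    return res, out.getvalue(), err.getvalue()


def noisy() -> int:
    print("to stdout")
    print("to stderr", file=sys.stderr)  # sys.stderr is redirected at call time
    return 42


res, out, err = capture_output(noisy)
print(res, repr(out), repr(err))  # 42 'to stdout\n' 'to stderr\n'
```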
- get_dump_file(name: str, folder: str | None = None) -> str
Returns a filename to dump a model.
- classmethod setUpClass()
Hook method for setting up class fixture before running tests in the class.
- classmethod tearDownClass()
Hook method for deconstructing the class fixture after running all tests in the class.
- tryCall(fct: Callable, msg: str | None = None, none_if: str | None = None) -> Any | None
Calls the function and catches any error.
- Parameters:
fct – function to call
msg – error message to display if the call fails
none_if – returns None if this substring is found in the error message
- Returns:
output of fct, or None when the error message contains none_if
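This is a sketch of the tryCall contract described above. It assumes two behaviours. An error whose message contains none_if becomes None. Any other error is re-raised as an AssertionError carrying msg. The function name try_call and this re-raising behaviour are assumptions, not the library's code.

```python
from typing import Any, Callable, Optional


def try_call(
    fct: Callable[[], Any],
    msg: Optional[str] = None,
    none_if: Optional[str] = None,
) -> Optional[Any]:
    """Hypothetical sketch of ExtTestCase.tryCall."""
    try:
        return fct()
    except Exception as e:
        if none_if is not None and none_if in str(e):
            # the error is expected: report it as a missing result
            return None
        # assumption: other errors are re-raised with the extra message
        raise AssertionError(f"{msg or 'call failed'}: {e}") from e


print(try_call(lambda: 1 + 1))                                 # 2
print(try_call(lambda: int("x"), none_if="invalid literal"))   # None
```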
- onnx_diagnostic.ext_test_case.dump_dort_onnx(fn)
Decorator that dumps the ONNX models created by dort while the decorated test function fn runs.
- onnx_diagnostic.ext_test_case.has_executorch(version: str = '', msg: str = '') -> Callable
Tells if ExecuTorch is installed.
- onnx_diagnostic.ext_test_case.has_onnxruntime_training(push_back_batch: bool = False)
Tells if onnxruntime_training is installed.
- onnx_diagnostic.ext_test_case.has_onnxscript(version: str, msg: str = '') -> Callable
Skips a unit test if onnxscript is not recent enough.
- onnx_diagnostic.ext_test_case.has_torch(version: str) -> bool
Returns True if the installed torch version is at least version.
- onnx_diagnostic.ext_test_case.has_transformers(version: str) -> bool
Returns True if the installed transformers version is at least version.
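has_torch and has_transformers compare an installed version string with version. The sketch below assumes an "at least this version" rule. It compares only the leading numeric components, so a pre-release such as 2.6.0.dev20250101 counts as 2.6.0. A real implementation would more likely use packaging.version, which orders pre-releases before the final release. The helper names numeric_version and is_at_least are hypothetical.

```python
import re
from typing import Tuple


def numeric_version(text: str) -> Tuple[int, ...]:
    """Hypothetical helper: keeps the leading dotted integers of a version string."""
    match = re.match(r"\d+(\.\d+)*", text)
    if match is None:
        raise ValueError(f"cannot parse version {text!r}")
    return tuple(int(p) for p in match.group(0).split("."))


def is_at_least(installed: str, version: str) -> bool:
    """True if installed >= version, comparing numeric components only."""
    a, b = numeric_version(installed), numeric_version(version)
    # pad with zeros so that '2.6' and '2.6.0' compare equal
    n = max(len(a), len(b))
    return a + (0,) * (n - len(a)) >= b + (0,) * (n - len(b))


print(is_at_least("2.6.0.dev20250101", "2.6"))  # True
print(is_at_least("4.48.3", "4.51"))            # False
```

Comparing integer tuples rather than raw strings avoids the classic mistake where "10.0" sorts before "9.9".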
- onnx_diagnostic.ext_test_case.hide_stdout(f: Callable | None = None) -> Callable
Catches warnings and hides standard output. This can be disabled by setting the environment variable UNHIDE=1 before running the unit test.
- Parameters:
f – if given, this function is called with the captured standard output as its argument
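A decorator with the behaviour described for hide_stdout can be sketched as follows. It assumes that the hidden output is handed to f after the test function returns, and that UNHIDE=1 bypasses everything. The name hide_stdout_sketch is hypothetical.

```python
import functools
import os
import warnings
from contextlib import redirect_stdout
from io import StringIO
from typing import Callable, Optional


def hide_stdout_sketch(f: Optional[Callable[[str], None]] = None) -> Callable:
    """Hypothetical decorator factory: hides stdout and warnings unless UNHIDE=1."""

    def wrapper(fct: Callable) -> Callable:
        @functools.wraps(fct)
        def call_f(*args, **kwargs):
            if os.environ.get("UNHIDE", "") in ("1", "True", "true"):
                # debugging mode: let everything through
                return fct(*args, **kwargs)
            out = StringIO()
            with redirect_stdout(out), warnings.catch_warnings():
                warnings.simplefilter("ignore")
                res = fct(*args, **kwargs)
            if f is not None:
                # hand the hidden output to the callback
                f(out.getvalue())
            return res

        return call_f

    return wrapper


seen = []


@hide_stdout_sketch(seen.append)
def chatty() -> str:
    print("hidden")
    warnings.warn("also hidden")
    return "done"


print(chatty(), seen)
```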
- onnx_diagnostic.ext_test_case.ignore_warnings(warns: List[Warning]) -> Callable
Catches warnings.
- Parameters:
warns – warnings to ignore
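This is a sketch of ignore_warnings built on warnings.catch_warnings. It assumes each entry of warns is a warning category (a subclass of Warning) passed to warnings.simplefilter("ignore", category). The name ignore_warnings_sketch is hypothetical.

```python
import functools
import warnings
from typing import Callable, List, Type


def ignore_warnings_sketch(warns: List[Type[Warning]]) -> Callable:
    """Hypothetical decorator: ignores the listed warning categories in the test."""

    def wrapper(fct: Callable) -> Callable:
        @functools.wraps(fct)
        def call_f(*args, **kwargs):
            # filters added here are restored when the block exits
            with warnings.catch_warnings():
                for category in warns:
                    warnings.simplefilter("ignore", category)
                return fct(*args, **kwargs)

        return call_f

    return wrapper


@ignore_warnings_sketch([DeprecationWarning])
def old_api() -> int:
    warnings.warn("old", DeprecationWarning)
    return 1


print(old_api())
```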
- onnx_diagnostic.ext_test_case.is_azure() -> bool
Tells if the job is running on Azure DevOps.
- onnx_diagnostic.ext_test_case.long_test(msg: str = '') -> Callable
Skips a unit test if it runs on an Azure pipeline on Windows.
- onnx_diagnostic.ext_test_case.measure_time(stmt: str | Callable, context: Dict[str, Any] | None = None, repeat: int = 10, number: int = 50, warmup: int = 1, div_by_number: bool = True, max_time: float | None = None) -> Dict[str, str | int | float]
Measures the execution time of a statement and returns the results as a dictionary.
- Parameters:
stmt – string or callable to measure
context – dictionary of the variables the statement needs
repeat – number of experiments the results are averaged over
number – number of consecutive executions in one experiment
warmup – number of iterations to run before the measurement starts
div_by_number – divide the measured times by number
max_time – if set, keeps executing the statement until the total time goes beyond this value (approximately); repeat is then ignored and div_by_number must be True
- Returns:
dictionary
<<<
from pprint import pprint
from math import cos
from onnx_diagnostic.ext_test_case import measure_time

res = measure_time(lambda: cos(0.5))
pprint(res)
>>>
{'average': np.float64(4.369397356640547e-08), 'context_size': 64, 'deviation': np.float64(5.784608579894682e-09), 'max_exec': np.float64(6.05599780101329e-08), 'min_exec': np.float64(4.007997631561011e-08), 'number': 50, 'repeat': 10, 'ttime': np.float64(4.369397356640547e-07), 'warmup_time': 1.2524997146101668e-05}
See timeit.Timer.repeat for the meaning of the parameters repeat and number. Each timing covers number executions of the main statement; when div_by_number is True, it is divided by number so that the reported values correspond to a single execution.
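The measurement scheme maps onto timeit.Timer.repeat. The simplified sketch below is not the library's implementation. It runs warmup executions, collects repeat timings of number executions each, and divides by number when div_by_number is True. It returns only some of the keys shown in the example output (average, deviation, min_exec, max_exec, repeat, number) and does not support max_time. The name measure_time_sketch is hypothetical.

```python
import statistics
import timeit
from math import cos
from typing import Any, Callable, Dict, Optional, Union


def measure_time_sketch(
    stmt: Union[str, Callable[[], Any]],
    context: Optional[Dict[str, Any]] = None,
    repeat: int = 10,
    number: int = 50,
    warmup: int = 1,
    div_by_number: bool = True,
) -> Dict[str, Union[int, float]]:
    """Simplified, hypothetical version of measure_time (no max_time)."""
    timer = timeit.Timer(stmt, globals=context or {})
    timer.timeit(number=warmup)  # warmup runs, not recorded
    # each entry is the duration of `number` consecutive executions
    res = timer.repeat(repeat=repeat, number=number)
    if div_by_number:
        res = [r / number for r in res]
    return dict(
        average=statistics.mean(res),
        deviation=statistics.pstdev(res),
        min_exec=min(res),
        max_exec=max(res),
        repeat=repeat,
        number=number,
    )


print(measure_time_sketch(lambda: cos(0.5)))
# a string statement reads its variables from context
print(measure_time_sketch("cos(x)", context={"cos": cos, "x": 0.5}, repeat=3))
```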
- onnx_diagnostic.ext_test_case.requires_cuda(msg: str = '', version: str = '', memory: int = 0)
Skips a test if CUDA is not available.
- Parameters:
msg – message replacing the default one
version – minimum version
memory – minimum amount of GPU memory, in GB, needed to run the test
- onnx_diagnostic.ext_test_case.requires_diffusers(version: str, msg: str = '', or_older_than: str | None = None) -> Callable
Skips a unit test if diffusers is not recent enough.
- onnx_diagnostic.ext_test_case.requires_executorch(version: str, msg: str = '') -> Callable
Skips a unit test if executorch is not recent enough.
- onnx_diagnostic.ext_test_case.requires_monai(version: str = '', msg: str = '') -> Callable
Skips a unit test if monai is not recent enough.
- onnx_diagnostic.ext_test_case.requires_numpy(version: str, msg: str = '') -> Callable
Skips a unit test if numpy is not recent enough.
- onnx_diagnostic.ext_test_case.requires_onnx(version: str, msg: str = '') -> Callable
Skips a unit test if onnx is not recent enough.
- onnx_diagnostic.ext_test_case.requires_onnx_array_api(version: str, msg: str = '') -> Callable
Skips a unit test if onnx-array-api is not recent enough.
- onnx_diagnostic.ext_test_case.requires_onnxruntime(version: str, msg: str = '') -> Callable
Skips a unit test if onnxruntime is not recent enough.
- onnx_diagnostic.ext_test_case.requires_onnxruntime_training(push_back_batch: bool = False, ortmodule: bool = False, msg: str = '') -> Callable
Skips a unit test if the installed onnxruntime is not onnxruntime_training.
- onnx_diagnostic.ext_test_case.requires_onnxscript(version: str, msg: str = '') -> Callable
Skips a unit test if onnxscript is not recent enough.
- onnx_diagnostic.ext_test_case.requires_pyinstrument(version: str = '', msg: str = '') -> Callable
Skips a unit test if pyinstrument is not recent enough.
- onnx_diagnostic.ext_test_case.requires_sklearn(version: str, msg: str = '') -> Callable
Skips a unit test if scikit-learn is not recent enough.
- onnx_diagnostic.ext_test_case.requires_torch(version: str, msg: str = '') -> Callable
Skips a unit test if pytorch is not recent enough.
- onnx_diagnostic.ext_test_case.requires_transformers(version: str, msg: str = '', or_older_than: str | None = None) -> Callable
Skips a unit test if transformers is not recent enough.
- onnx_diagnostic.ext_test_case.requires_vocos(version: str = '', msg: str = '') -> Callable
Skips a unit test if vocos is not recent enough.
- onnx_diagnostic.ext_test_case.requires_zoo(msg: str = '') -> Callable
Skips a unit test if the environment variable ZOO is not equal to 1.
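The requires_* helpers appear to return a decorator (the Callable in their signatures) that skips a test when a condition does not hold. The sketch below shows that pattern for the ZOO rule above, built on unittest.skipUnless. The name requires_zoo_sketch and its default message are hypothetical.

```python
import os
import unittest
from typing import Callable


def requires_zoo_sketch(msg: str = "") -> Callable:
    """Hypothetical: skips the decorated test unless the environment variable ZOO is '1'."""
    enabled = os.environ.get("ZOO", "0") == "1"
    # skipUnless returns an identity decorator when the condition holds
    return unittest.skipUnless(enabled, msg or "set ZOO=1 to run this test")


class TestZoo(unittest.TestCase):
    @requires_zoo_sketch()
    def test_big_model(self):
        self.assertTrue(True)
```

The condition is evaluated once, when the decorator is applied, so ZOO must be set before the test module is imported.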
- onnx_diagnostic.ext_test_case.skipif_ci_apple(msg) -> Callable
Skips a unit test if it runs on an Azure pipeline on macOS.
- onnx_diagnostic.ext_test_case.skipif_ci_linux(msg) -> Callable
Skips a unit test if it runs on an Azure pipeline on Linux.
- onnx_diagnostic.ext_test_case.skipif_ci_windows(msg) -> Callable
Skips a unit test if it runs on an Azure pipeline on Windows.
- onnx_diagnostic.ext_test_case.skipif_not_onnxrt(msg) -> Callable
Skips a unit test if it runs on an Azure pipeline on Windows.
- onnx_diagnostic.ext_test_case.skipif_transformers(version_to_skip: str | Set[str], msg: str) -> Callable
Skips a unit test if the installed transformers version is the one, or one of those, given in version_to_skip.
- onnx_diagnostic.ext_test_case.statistics_on_file(filename: str) -> Dict[str, str | int | float]
Computes statistics on a file.
<<<
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_file, __file__

pprint.pprint(statistics_on_file(__file__))
>>>
{'chars': 26692, 'ext': '.py', 'lines': 901}
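This is a standard-library sketch of the three values shown in the output (chars, ext, lines). It assumes chars is the length of the decoded text and lines is the number of lines it contains. The library may count differently, and the name file_statistics is hypothetical.

```python
import os
import tempfile
from typing import Dict, Union


def file_statistics(filename: str) -> Dict[str, Union[str, int]]:
    """Hypothetical counterpart of statistics_on_file."""
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()
    return {
        "chars": len(content),
        "ext": os.path.splitext(filename)[-1],
        "lines": len(content.splitlines()),
    }


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write("import os\nprint(os.name)\n")
    print(file_statistics(path))  # {'chars': 25, 'ext': '.py', 'lines': 2}
```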
- onnx_diagnostic.ext_test_case.statistics_on_folder(folder: str | List[str], pattern: str = '.*[.]((py|rst))$', aggregation: int = 0) -> List[Dict[str, str | int | float]]
Computes statistics on the files in a folder.
- Parameters:
folder – folder or list of folders to investigate
pattern – regular expression selecting the files to include
aggregation – if positive, adds a key dir with the leading subfolders of each file path (the first subfolder when aggregation=1, as in the aggregated example)
- Returns:
list of dictionaries, one per file
<<<
import os
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__

pprint.pprint(statistics_on_folder(os.path.dirname(__file__)))
>>>
[{'chars': 16173, 'ext': '.py', 'lines': 470, 'name': 'ort_session.py'}, {'chars': 6017, 'ext': '.py', 'lines': 202, 'name': 'torch_test_helper.py'}, {'chars': 26692, 'ext': '.py', 'lines': 901, 'name': 'ext_test_case.py'}, {'chars': 5661, 'ext': '.py', 'lines': 213, 'name': 'onnx_tools.py'}, {'chars': 1056, 'ext': '.py', 'lines': 26, 'name': 'cache_helpers.py'}, {'chars': 2878, 'ext': '.py', 'lines': 115, 'name': 'args.py'}, {'chars': 36792, 'ext': '.py', 'lines': 1342, 'name': 'helpers.py'}, {'chars': 136, 'ext': '.py', 'lines': 4, 'name': '__init__.py'}, {'chars': 3635, 'ext': '.py', 'lines': 112, 'name': 'torch_export_patches/onnx_export_serialization.py'}, {'chars': 101, 'ext': '.py', 'lines': 3, 'name': 'torch_export_patches/__init__.py'}, {'chars': 14028, 'ext': '.py', 'lines': 360, 'name': 'torch_export_patches/onnx_export_errors.py'}, {'chars': 6330, 'ext': '.py', 'lines': 189, 'name': 'torch_export_patches/patches/patch_transformers.py'}, {'chars': 4137, 'ext': '.py', 'lines': 127, 'name': 'torch_export_patches/patches/patch_torch.py'}, {'chars': 0, 'ext': '.py', 'lines': 0, 'name': 'torch_export_patches/patches/__init__.py'}, {'chars': 0, 'ext': '.py', 'lines': 0, 'name': 'torch_models/__init__.py'}, {'chars': 2191, 'ext': '.py', 'lines': 72, 'name': 'torch_models/llms.py'}]
Aggregated:
<<<
import os
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__

pprint.pprint(statistics_on_folder(os.path.dirname(__file__), aggregation=1))
>>>
[{'chars': 16173, 'dir': '', 'ext': '.py', 'lines': 470, 'name': 'ort_session.py'}, {'chars': 6017, 'dir': '', 'ext': '.py', 'lines': 202, 'name': 'torch_test_helper.py'}, {'chars': 26692, 'dir': '', 'ext': '.py', 'lines': 901, 'name': 'ext_test_case.py'}, {'chars': 5661, 'dir': '', 'ext': '.py', 'lines': 213, 'name': 'onnx_tools.py'}, {'chars': 1056, 'dir': '', 'ext': '.py', 'lines': 26, 'name': 'cache_helpers.py'}, {'chars': 2878, 'dir': '', 'ext': '.py', 'lines': 115, 'name': 'args.py'}, {'chars': 36792, 'dir': '', 'ext': '.py', 'lines': 1342, 'name': 'helpers.py'}, {'chars': 136, 'dir': '', 'ext': '.py', 'lines': 4, 'name': '__init__.py'}, {'chars': 3635, 'dir': 'torch_export_patches', 'ext': '.py', 'lines': 112, 'name': 'torch_export_patches/onnx_export_serialization.py'}, {'chars': 101, 'dir': 'torch_export_patches', 'ext': '.py', 'lines': 3, 'name': 'torch_export_patches/__init__.py'}, {'chars': 14028, 'dir': 'torch_export_patches', 'ext': '.py', 'lines': 360, 'name': 'torch_export_patches/onnx_export_errors.py'}, {'chars': 6330, 'dir': 'torch_export_patches', 'ext': '.py', 'lines': 189, 'name': 'torch_export_patches/patches/patch_transformers.py'}, {'chars': 4137, 'dir': 'torch_export_patches', 'ext': '.py', 'lines': 127, 'name': 'torch_export_patches/patches/patch_torch.py'}, {'chars': 0, 'dir': 'torch_export_patches', 'ext': '.py', 'lines': 0, 'name': 'torch_export_patches/patches/__init__.py'}, {'chars': 0, 'dir': 'torch_models', 'ext': '.py', 'lines': 0, 'name': 'torch_models/__init__.py'}, {'chars': 2191, 'dir': 'torch_models', 'ext': '.py', 'lines': 72, 'name': 'torch_models/llms.py'}]
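This is a self-contained sketch of the folder walk for a single folder. It makes three assumptions:
- pattern is matched with re.search against file names.
- name is the path relative to folder, with '/' separators.
- aggregation=n stores the first n subfolders of that path under dir. For aggregation=1, this reproduces the aggregated output above.

The name folder_statistics is hypothetical.

```python
import os
import re
import tempfile
from typing import Dict, List, Union

Stat = Dict[str, Union[str, int]]


def folder_statistics(
    folder: str, pattern: str = ".*[.]((py|rst))$", aggregation: int = 0
) -> List[Stat]:
    """Hypothetical counterpart of statistics_on_folder for a single folder."""
    regex = re.compile(pattern)
    rows: List[Stat] = []
    for root, _dirs, files in os.walk(folder):
        for fname in sorted(files):
            if not regex.search(fname):
                continue
            full = os.path.join(root, fname)
            # relative name with '/' separators, as in the output above
            name = os.path.relpath(full, folder).replace(os.sep, "/")
            with open(full, "r", encoding="utf-8") as f:
                content = f.read()
            row: Stat = {
                "chars": len(content),
                "ext": os.path.splitext(fname)[-1],
                "lines": len(content.splitlines()),
                "name": name,
            }
            if aggregation > 0:
                # keep only the first `aggregation` subfolders of the path
                row["dir"] = "/".join(name.split("/")[:-1][:aggregation])
            rows.append(row)
    return rows


with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "pkg", "sub"))
    for rel in ["a.py", "pkg/b.py", "pkg/sub/c.py", "notes.txt"]:
        with open(os.path.join(tmp, *rel.split("/")), "w", encoding="utf-8") as f:
            f.write("x = 1\n")
    for row in folder_statistics(tmp, aggregation=1):
        print(row["name"], repr(row["dir"]))
```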