onnx_diagnostic.ext_test_case

This module contains the main class ExtTestCase, which adds functionality specific to this project.

class onnx_diagnostic.ext_test_case.ExtTestCase(methodName='runTest')[source]

Inherits from unittest.TestCase and adds specific comparison functions and other helpers.

assertAlmostEqual(expected: ndarray, value: ndarray, atol: float = 0, rtol: float = 0)[source]

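Checks that value is almost equal to expected, within the absolute tolerance atol and the relative tolerance rtol.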
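The role of the two tolerances can be illustrated with plain floats. This is a minimal sketch, assuming the usual NumPy convention |value - expected| <= atol + rtol * |expected|; the source does not spell out the formula, and `almost_equal` is a hypothetical helper, not the method's implementation.

```python
def almost_equal(expected, value, atol=0.0, rtol=0.0):
    """Hypothetical helper: element-wise tolerance check on two sequences of floats."""
    if len(expected) != len(value):
        return False
    # Assumed convention: the allowed error grows with the magnitude of `expected`.
    return all(abs(v - e) <= atol + rtol * abs(e) for e, v in zip(expected, value))


exact = almost_equal([1.0, 2.0], [1.0, 2.0])                  # no tolerance needed
strict = almost_equal([1.0, 2.0], [1.0, 2.001])               # fails with zero tolerance
relaxed = almost_equal([1.0, 2.0], [1.0, 2.001], atol=1e-2)   # passes with atol
```

With both tolerances left at their default of 0, the check reduces to exact equality.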

assertEmpty(value: Any)[source]

Checks that value is empty.

assertEqual(expected: Any, value: Any, msg: str = '')[source]

Overrides the error message to make explicit which value is the expected one and which is the actual one.

assertEqualArray(expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str | None = None)[source]

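Checks that the array value is equal to expected, within the absolute tolerance atol and the relative tolerance rtol.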

assertEqualArrays(expected: Sequence[ndarray], value: Sequence[ndarray], atol: float = 0, rtol: float = 0, msg: str | None = None)[source]

Checks that two sequences of arrays are equal element by element, within the tolerances atol and rtol.

assertExists(name)[source]

Checks that a file exists.

assertGreaterOrEqual(a, b, msg=None)[source]

Checks that a is greater than or equal to b.

assertIn(tofind: str, text: str, msg: str = '')[source]

Just like self.assertTrue(tofind in text), but with a nicer default message.

assertNotEmpty(value: Any)[source]

Checks that value is not empty.

assertRaise(fct: Callable, exc_type: type[Exception], msg: str | None = None)[source]

Checks that calling fct raises an exception of type exc_type.

assertSetContained(set1, set2)[source]

Checks that set1 is contained in set2.

assertStartsWith(prefix: str, full: str)[source]

Checks that full starts with prefix.

capture(fct: Callable)[source]

Runs a function and captures the standard output and the standard error.

Parameters:

fct – function to run

Returns:

result of fct, output, error
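The capture pattern described above can be sketched with the standard library alone. The helper `run_and_capture` below is hypothetical and only mirrors the documented return value (result, output, error); it is not the method's actual code.

```python
import io
import sys
from contextlib import redirect_stderr, redirect_stdout


def run_and_capture(fct):
    """Hypothetical helper: runs fct and returns (result, stdout text, stderr text)."""
    out, err = io.StringIO(), io.StringIO()
    # Both streams are swapped only for the duration of the call.
    with redirect_stdout(out), redirect_stderr(err):
        res = fct()
    return res, out.getvalue(), err.getvalue()


def noisy():
    print("hello")
    print("oops", file=sys.stderr)
    return 3


res, out, err = run_and_capture(noisy)
```

Redirection at this level only catches what goes through Python's sys.stdout and sys.stderr; output written directly by native code would need a different approach.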

dump_onnx(name: str, proto: Any, folder: str | None = None) str[source]

Dumps an onnx file.

get_dump_file(name: str, folder: str | None = None) str[source]

Returns a filename to dump a model.

print_model(model: ModelProto)[source]

Prints a ModelProto.

print_onnx(model: ModelProto)[source]

Prints a ModelProto.

classmethod setUpClass()[source]

Hook method for setting up class fixture before running tests in the class.

classmethod tearDownClass()[source]

Hook method for deconstructing the class fixture after running all tests in the class.

classmethod todo(f: Callable, msg: str)[source]

Adds a todo that is printed once all tests have run.

tryCall(fct: Callable, msg: str | None = None, none_if: str | None = None) Any | None[source]

Calls the function and catches any error.

Parameters:
  • fct – function to call

  • msg – error message to display if failing

  • none_if – returns None if this substring is found in the error message

Returns:

output of fct
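The interplay of the three parameters can be sketched as follows. This is an illustration under assumptions: the source only states that None is returned when none_if appears in the error message, so what happens otherwise (re-raising, or raising an AssertionError carrying msg) is a guess, and `try_call` is a hypothetical stand-alone function.

```python
def try_call(fct, msg=None, none_if=None):
    """Hypothetical sketch of a guarded call."""
    try:
        return fct()
    except Exception as e:
        # Documented behavior: swallow the error if the substring is found.
        if none_if is not None and none_if in str(e):
            return None
        # Assumed behavior: surface the failure, with msg if one was given.
        if msg is None:
            raise
        raise AssertionError(msg) from e


ok = try_call(lambda: 1 + 1)
ignored = try_call(lambda: 1 / 0, none_if="division")
```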

onnx_diagnostic.ext_test_case.dump_dort_onnx(fn)[source]

Context manager to dump onnx model created by dort.

onnx_diagnostic.ext_test_case.get_figure(ax)[source]

Returns the figure of a matplotlib axis.

onnx_diagnostic.ext_test_case.has_cuda() bool[source]

Returns torch.cuda.device_count() > 0.

onnx_diagnostic.ext_test_case.has_executorch(version: str = '', msg: str = '') Callable[source]

Tells if ExecuTorch is installed.

onnx_diagnostic.ext_test_case.has_onnxruntime_training(push_back_batch: bool = False)[source]

Tells if onnxruntime_training is installed.

onnx_diagnostic.ext_test_case.has_onnxscript(version: str, msg: str = '') Callable[source]

Skips a unit test if onnxscript is not recent enough.

onnx_diagnostic.ext_test_case.has_torch(version: str) bool[source]

Returns True if the torch version is higher than version.

onnx_diagnostic.ext_test_case.has_transformers(version: str) bool[source]

Returns True if the transformers version is higher than version.

onnx_diagnostic.ext_test_case.hide_stdout(f: Callable | None = None) Callable[source]

Catches warnings, hides standard output. The function may be disabled by setting UNHIDE=1 before running the unit test.

Parameters:

f – the function is called with the stdout as an argument
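A decorator with this behavior could look like the sketch below. It assumes that f receives the captured text once the wrapped function has returned and that UNHIDE=1 bypasses the capture entirely; `hide_stdout_sketch` is a hypothetical name and the code is not taken from the library.

```python
import io
import os
import warnings
from contextlib import redirect_stdout
from functools import wraps


def hide_stdout_sketch(f=None):
    """Hypothetical decorator factory: hides stdout and warnings of the wrapped function."""

    def wrapper(fct):
        @wraps(fct)
        def call(*args, **kwargs):
            # UNHIDE=1 disables the capture, which helps when debugging a test.
            if os.environ.get("UNHIDE", "") == "1":
                return fct(*args, **kwargs)
            buffer = io.StringIO()
            with warnings.catch_warnings(), redirect_stdout(buffer):
                warnings.simplefilter("ignore")
                res = fct(*args, **kwargs)
            if f is not None:
                f(buffer.getvalue())  # assumption: f receives the captured text
            return res

        return call

    return wrapper


seen = []


@hide_stdout_sketch(seen.append)
def chatty():
    print("verbose log")
    return 42
```

Calling chatty() returns 42 while the printed line ends up in seen instead of the terminal, unless UNHIDE=1 is set.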

onnx_diagnostic.ext_test_case.ignore_warnings(warns: List[Warning]) Callable[source]

Catches warnings.

Parameters:

warns – warnings to ignore
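The same idea can be written with the standard warnings module. The sketch below is an illustration, not the library's code: `ignore_warnings_sketch` is a hypothetical name, and it assumes warns is a list of warning categories.

```python
import warnings
from functools import wraps


def ignore_warnings_sketch(warns):
    """Hypothetical decorator factory: silences the given warning categories."""

    def wrapper(fct):
        @wraps(fct)
        def call(*args, **kwargs):
            # Filters are restored when the context manager exits.
            with warnings.catch_warnings():
                for category in warns:
                    warnings.simplefilter("ignore", category)
                return fct(*args, **kwargs)

        return call

    return wrapper


@ignore_warnings_sketch([DeprecationWarning])
def legacy():
    warnings.warn("old API", DeprecationWarning)
    return "done"
```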

onnx_diagnostic.ext_test_case.is_azure() bool[source]

Tells if the job is running on Azure DevOps.

onnx_diagnostic.ext_test_case.long_test(msg: str = '') Callable[source]

Skips a unit test if it runs on azure pipeline on Windows.

onnx_diagnostic.ext_test_case.measure_time(stmt: str | Callable, context: Dict[str, Any] | None = None, repeat: int = 10, number: int = 50, warmup: int = 1, div_by_number: bool = True, max_time: float | None = None) Dict[str, str | int | float][source]

Measures the execution time of a statement and returns the results as a dictionary.

Parameters:
  • stmt – string or callable

  • context – variables the statement needs, given as a dictionary

  • repeat – average over repeat experiments

  • number – number of executions in one row

  • warmup – number of iterations to run before starting the real measurement

  • div_by_number – divide by the number of executions

  • max_time – executes the statement until the total time goes beyond this value (approximately); repeat is ignored and div_by_number must be set to True

Returns:

dictionary

<<<

from pprint import pprint
from math import cos
from onnx_diagnostic.ext_test_case import measure_time

res = measure_time(lambda: cos(0.5))
pprint(res)

>>>

    {'average': np.float64(4.369397356640547e-08),
     'context_size': 64,
     'deviation': np.float64(5.784608579894682e-09),
     'max_exec': np.float64(6.05599780101329e-08),
     'min_exec': np.float64(4.007997631561011e-08),
     'number': 50,
     'repeat': 10,
     'ttime': np.float64(4.369397356640547e-07),
     'warmup_time': 1.2524997146101668e-05}

See Timer.repeat for a better understanding of the parameters repeat and number. Timer.repeat returns a duration corresponding to number executions of the main statement.
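The relation between repeat and number can be checked directly with timeit.Timer.repeat from the standard library. The snippet below is only a sketch of that relation; how measure_time aggregates the raw durations internally is not shown in this page, so the division step is an assumption about what div_by_number does.

```python
from math import cos
from timeit import Timer

repeat, number = 3, 1000
timer = Timer(lambda: cos(0.5))

# One duration per experiment, each covering `number` consecutive executions.
raw = timer.repeat(repeat=repeat, number=number)

# Dividing by `number` gives a per-execution time (assumed meaning of div_by_number).
per_call = [t / number for t in raw]
average = sum(per_call) / len(per_call)
```

Here raw holds repeat values, and average is comparable to the 'average' entry of the dictionary shown above only under the stated assumption.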

onnx_diagnostic.ext_test_case.never_test(msg: str = '') Callable[source]

Skips a unit test.

onnx_diagnostic.ext_test_case.requires_cuda(msg: str = '', version: str = '', memory: int = 0)[source]

Skips a test if cuda is not available.

Parameters:
  • msg – to overwrite the message

  • version – minimum version

  • memory – minimum number of Gb to run the test

onnx_diagnostic.ext_test_case.requires_diffusers(version: str, msg: str = '', or_older_than: str | None = None) Callable[source]

Skips a unit test if diffusers is not recent enough.

onnx_diagnostic.ext_test_case.requires_executorch(version: str, msg: str = '') Callable[source]

Skips a unit test if executorch is not recent enough.

onnx_diagnostic.ext_test_case.requires_monai(version: str = '', msg: str = '') Callable[source]

Skips a unit test if monai is not recent enough.

onnx_diagnostic.ext_test_case.requires_numpy(version: str, msg: str = '') Callable[source]

Skips a unit test if numpy is not recent enough.

onnx_diagnostic.ext_test_case.requires_onnx(version: str, msg: str = '') Callable[source]

Skips a unit test if onnx is not recent enough.

onnx_diagnostic.ext_test_case.requires_onnx_array_api(version: str, msg: str = '') Callable[source]

Skips a unit test if onnx-array-api is not recent enough.

onnx_diagnostic.ext_test_case.requires_onnxruntime(version: str, msg: str = '') Callable[source]

Skips a unit test if onnxruntime is not recent enough.

onnx_diagnostic.ext_test_case.requires_onnxruntime_training(push_back_batch: bool = False, ortmodule: bool = False, msg: str = '') Callable[source]

Skips a unit test if onnxruntime is not onnxruntime_training.

onnx_diagnostic.ext_test_case.requires_onnxscript(version: str, msg: str = '') Callable[source]

Skips a unit test if onnxscript is not recent enough.

onnx_diagnostic.ext_test_case.requires_pyinstrument(version: str = '', msg: str = '') Callable[source]

Skips a unit test if pyinstrument is not recent enough.

onnx_diagnostic.ext_test_case.requires_sklearn(version: str, msg: str = '') Callable[source]

Skips a unit test if scikit-learn is not recent enough.

onnx_diagnostic.ext_test_case.requires_torch(version: str, msg: str = '') Callable[source]

Skips a unit test if pytorch is not recent enough.
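The requires_* decorators all follow the same pattern: compare an installed version with a minimum and skip the test when it is too old. A minimal sketch of that pattern, using only the standard library, is given below. The names `version_tuple` and `requires_package` are hypothetical, and the hand-written version parsing is a simplification; a production version would more likely rely on a dedicated version parser.

```python
import re
import unittest
from importlib import metadata


def version_tuple(version, size=4):
    """Keeps the leading integer of each dot-separated piece, padded with zeros."""
    numbers = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stops at suffixes such as 'dev20240101'
        numbers.append(int(match.group()))
    numbers = numbers[:size]
    return tuple(numbers + [0] * (size - len(numbers)))


def requires_package(name, version, msg=""):
    """Hypothetical decorator: skips a test if `name` is missing or older than `version`."""
    try:
        installed = metadata.version(name)
    except metadata.PackageNotFoundError:
        return unittest.skip(msg or f"{name} is not installed")
    if version_tuple(installed) < version_tuple(version):
        return unittest.skip(msg or f"{name}>={version} is required, found {installed}")
    return lambda fct: fct


class TestExample(unittest.TestCase):
    @requires_package("a-package-that-does-not-exist", "1.0")
    def test_needs_package(self):
        self.fail("never runs")
```

Comparing tuples of integers avoids the classic pitfall of string comparison, where "2.10" would sort before "2.9".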

onnx_diagnostic.ext_test_case.requires_transformers(version: str, msg: str = '', or_older_than: str | None = None) Callable[source]

Skips a unit test if transformers is not recent enough.

onnx_diagnostic.ext_test_case.requires_vocos(version: str = '', msg: str = '') Callable[source]

Skips a unit test if vocos is not recent enough.

onnx_diagnostic.ext_test_case.requires_zoo(msg: str = '') Callable[source]

Skips a unit test if environment variable ZOO is not equal to 1.
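A skip driven by an environment variable can be expressed with unittest.skipIf. The sketch below uses a hypothetical helper `requires_env_flag` and assumes the variable is read when the decorator is applied, not when the test runs.

```python
import os
import unittest


def requires_env_flag(name, msg=""):
    """Hypothetical decorator: skips a test unless the environment variable equals '1'."""
    enabled = os.environ.get(name, "0") == "1"
    return unittest.skipIf(not enabled, msg or f"{name} is not set to 1")


class TestZoo(unittest.TestCase):
    @requires_env_flag("ZOO")
    def test_download(self):
        # Would run only when the test suite is launched with ZOO=1.
        self.assertTrue(True)
```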

onnx_diagnostic.ext_test_case.skipif_ci_apple(msg) Callable[source]

Skips a unit test if it runs on azure pipeline on macOS.

onnx_diagnostic.ext_test_case.skipif_ci_linux(msg) Callable[source]

Skips a unit test if it runs on azure pipeline on Linux.

onnx_diagnostic.ext_test_case.skipif_ci_windows(msg) Callable[source]

Skips a unit test if it runs on azure pipeline on Windows.

onnx_diagnostic.ext_test_case.skipif_not_onnxrt(msg) Callable[source]

Skips a unit test if it runs on azure pipeline on Windows.

onnx_diagnostic.ext_test_case.skipif_transformers(version_to_skip: str | Set[str], msg: str) Callable[source]

Skips a unit test if transformers has a specific version.

onnx_diagnostic.ext_test_case.statistics_on_file(filename: str) Dict[str, str | int | float][source]

Computes statistics on a file.

<<<

import pprint
from onnx_diagnostic.ext_test_case import statistics_on_file, __file__

pprint.pprint(statistics_on_file(__file__))

>>>

    {'chars': 26692, 'ext': '.py', 'lines': 901}

onnx_diagnostic.ext_test_case.statistics_on_folder(folder: str | List[str], pattern: str = '.*[.]((py|rst))$', aggregation: int = 0) List[Dict[str, str | int | float]][source]

Computes statistics on files in a folder.

Parameters:
  • folder – folder or folders to investigate

  • pattern – file pattern

  • aggregation – show the first subfolders

Returns:

list of dictionaries

<<<

import os
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__

pprint.pprint(statistics_on_folder(os.path.dirname(__file__)))

>>>

    [{'chars': 16173, 'ext': '.py', 'lines': 470, 'name': 'ort_session.py'},
     {'chars': 6017, 'ext': '.py', 'lines': 202, 'name': 'torch_test_helper.py'},
     {'chars': 26692, 'ext': '.py', 'lines': 901, 'name': 'ext_test_case.py'},
     {'chars': 5661, 'ext': '.py', 'lines': 213, 'name': 'onnx_tools.py'},
     {'chars': 1056, 'ext': '.py', 'lines': 26, 'name': 'cache_helpers.py'},
     {'chars': 2878, 'ext': '.py', 'lines': 115, 'name': 'args.py'},
     {'chars': 36792, 'ext': '.py', 'lines': 1342, 'name': 'helpers.py'},
     {'chars': 136, 'ext': '.py', 'lines': 4, 'name': '__init__.py'},
     {'chars': 3635,
      'ext': '.py',
      'lines': 112,
      'name': 'torch_export_patches/onnx_export_serialization.py'},
     {'chars': 101,
      'ext': '.py',
      'lines': 3,
      'name': 'torch_export_patches/__init__.py'},
     {'chars': 14028,
      'ext': '.py',
      'lines': 360,
      'name': 'torch_export_patches/onnx_export_errors.py'},
     {'chars': 6330,
      'ext': '.py',
      'lines': 189,
      'name': 'torch_export_patches/patches/patch_transformers.py'},
     {'chars': 4137,
      'ext': '.py',
      'lines': 127,
      'name': 'torch_export_patches/patches/patch_torch.py'},
     {'chars': 0,
      'ext': '.py',
      'lines': 0,
      'name': 'torch_export_patches/patches/__init__.py'},
     {'chars': 0, 'ext': '.py', 'lines': 0, 'name': 'torch_models/__init__.py'},
     {'chars': 2191, 'ext': '.py', 'lines': 72, 'name': 'torch_models/llms.py'}]

Aggregated:

<<<

import os
import pprint
from onnx_diagnostic.ext_test_case import statistics_on_folder, __file__

pprint.pprint(statistics_on_folder(os.path.dirname(__file__), aggregation=1))

>>>

    [{'chars': 16173,
      'dir': '',
      'ext': '.py',
      'lines': 470,
      'name': 'ort_session.py'},
     {'chars': 6017,
      'dir': '',
      'ext': '.py',
      'lines': 202,
      'name': 'torch_test_helper.py'},
     {'chars': 26692,
      'dir': '',
      'ext': '.py',
      'lines': 901,
      'name': 'ext_test_case.py'},
     {'chars': 5661,
      'dir': '',
      'ext': '.py',
      'lines': 213,
      'name': 'onnx_tools.py'},
     {'chars': 1056,
      'dir': '',
      'ext': '.py',
      'lines': 26,
      'name': 'cache_helpers.py'},
     {'chars': 2878, 'dir': '', 'ext': '.py', 'lines': 115, 'name': 'args.py'},
     {'chars': 36792, 'dir': '', 'ext': '.py', 'lines': 1342, 'name': 'helpers.py'},
     {'chars': 136, 'dir': '', 'ext': '.py', 'lines': 4, 'name': '__init__.py'},
     {'chars': 3635,
      'dir': 'torch_export_patches',
      'ext': '.py',
      'lines': 112,
      'name': 'torch_export_patches/onnx_export_serialization.py'},
     {'chars': 101,
      'dir': 'torch_export_patches',
      'ext': '.py',
      'lines': 3,
      'name': 'torch_export_patches/__init__.py'},
     {'chars': 14028,
      'dir': 'torch_export_patches',
      'ext': '.py',
      'lines': 360,
      'name': 'torch_export_patches/onnx_export_errors.py'},
     {'chars': 6330,
      'dir': 'torch_export_patches',
      'ext': '.py',
      'lines': 189,
      'name': 'torch_export_patches/patches/patch_transformers.py'},
     {'chars': 4137,
      'dir': 'torch_export_patches',
      'ext': '.py',
      'lines': 127,
      'name': 'torch_export_patches/patches/patch_torch.py'},
     {'chars': 0,
      'dir': 'torch_export_patches',
      'ext': '.py',
      'lines': 0,
      'name': 'torch_export_patches/patches/__init__.py'},
     {'chars': 0,
      'dir': 'torch_models',
      'ext': '.py',
      'lines': 0,
      'name': 'torch_models/__init__.py'},
     {'chars': 2191,
      'dir': 'torch_models',
      'ext': '.py',
      'lines': 72,
      'name': 'torch_models/llms.py'}]

onnx_diagnostic.ext_test_case.unit_test_going()[source]

Enables a flag indicating that the script is running as part of a unit test. This keeps unit tests from running too long.

onnx_diagnostic.ext_test_case.with_path_append(path_to_add: str | List[str]) Callable[source]

Adds a path to sys.path to check.
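A decorator that temporarily extends sys.path could be sketched as follows. Whether the original restores sys.path after the call is not stated in this page, so the restore step is an assumption, and both `with_path_append_sketch` and the directory name are hypothetical.

```python
import sys
from functools import wraps


def with_path_append_sketch(path_to_add):
    """Hypothetical decorator: adds one or several paths to sys.path during the call."""
    paths = [path_to_add] if isinstance(path_to_add, str) else list(path_to_add)

    def wrapper(fct):
        @wraps(fct)
        def call(*args, **kwargs):
            saved = sys.path.copy()
            sys.path.extend(paths)
            try:
                return fct(*args, **kwargs)
            finally:
                # Assumption: the original search path is restored afterwards.
                sys.path[:] = saved

        return call

    return wrapper


@with_path_append_sketch("/tmp/hypothetical_dir")
def peek():
    return "/tmp/hypothetical_dir" in sys.path


inside = peek()
```

Restoring the path in a finally clause keeps one test from leaking import locations into the next, even when the wrapped function raises.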