Source code for onnx_extended.reference.c_reference_backend

import os
import re
import unittest
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from unittest.case import SkipTest
import numpy
import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import FunctionProto, ModelProto, NodeProto
from onnx.backend.test import BackendTest
from onnx.backend.test.runner import Pattern, TestItem
from onnx.backend.test.loader import load_model_tests
from onnx.backend.base import Backend, Device, DeviceType
from .c_reference_evaluator import CReferenceEvaluator


class Runner:
    """
    Collects tests and runs them as unit tests.

    :param backend: a subclass of :class:`onnx.backend.base.Backend`
    :param path_to_test: folder to look at
    :param kind: subfolder (or list of subfolders) to test; when None,
        every subfolder of *path_to_test* is used
    :param test_kwargs: additional test parameters
    """

    # Borrowed from onnx's BackendTest so that tests are collected and
    # compared exactly the way the onnx backend test suite does it.
    _add_model_test = BackendTest._add_model_test
    _add_test = BackendTest._add_test
    _load_proto = BackendTest._load_proto
    assert_similar_outputs = BackendTest.assert_similar_outputs

    def __init__(
        self,
        backend: type[Backend],
        path_to_test: Optional[str] = None,
        kind: Optional[Union[str, List[str]]] = None,
        test_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        self.backend = backend
        self._include_patterns: Set[Pattern[str]] = set()
        self._exclude_patterns: Set[Pattern[str]] = set()
        self._xfail_patterns: Set[Pattern[str]] = set()
        self._test_kwargs: dict = test_kwargs or {}
        if path_to_test is None or not os.path.exists(path_to_test):
            raise FileNotFoundError(f"Unable to find path {path_to_test!r}.")
        if isinstance(kind, str):
            kind = [kind]
        elif kind is None:
            kind = [
                k
                for k in os.listdir(path_to_test)
                if os.path.isdir(os.path.join(path_to_test, k))
            ]
        # One bucket per test folder; _add_model_test fills the
        # "<kind>Model" category.
        self._test_items: Dict[str, Dict[str, TestItem]] = {
            f"{k}Model": {} for k in kind
        }
        for k in kind:
            for ot in load_model_tests(path_to_test, kind=k):
                self._add_model_test(ot, k)

    def include(self, pattern: str) -> "Runner":
        """
        Adds a regular expression. When at least one include pattern is
        registered, tests whose name matches none of them are skipped.
        Returns *self* so that calls can be chained.
        """
        self._include_patterns.add(re.compile(pattern))
        return self

    def exclude(self, pattern: str) -> "Runner":
        """
        Adds a regular expression; tests whose name matches it are skipped.
        Returns *self* so that calls can be chained.
        """
        self._exclude_patterns.add(re.compile(pattern))
        return self

    def xfail(self, pattern: str) -> "Runner":
        """
        Adds a regular expression; tests whose name matches it are marked
        with :func:`unittest.expectedFailure`. Returns *self* so that calls
        can be chained.
        """
        self._xfail_patterns.add(re.compile(pattern))
        return self

    def _filtered_test_items(self) -> dict[str, dict[str, TestItem]]:
        """
        Returns the collected tests with include/exclude/xfail patterns applied.

        The stored ``TestItem`` objects are not rebound: every returned item
        wraps the originally collected function. Calling this method again,
        possibly after changing the patterns, therefore starts from the
        unfiltered tests instead of stacking on previous decorations.
        """
        filtered: dict[str, dict[str, TestItem]] = {}
        for category, items_map in self._test_items.items():
            filtered[category] = {}
            for name, item in items_map.items():
                func = item.func
                if self._include_patterns and not any(
                    include.search(name) for include in self._include_patterns
                ):
                    func = unittest.skip("no matched include pattern")(func)
                for exclude in self._exclude_patterns:
                    if exclude.search(name):
                        func = unittest.skip(
                            f"matched exclude pattern '{exclude.pattern}'"
                        )(func)
                for xfail in self._xfail_patterns:
                    if xfail.search(name):
                        func = unittest.expectedFailure(func)
                filtered[category][name] = TestItem(func, item.proto)
        return filtered

    def tests(self, name: str = "CustomTestCase") -> type[unittest.TestCase]:
        """
        Returns a subclass of `unittest.TestCase` holding every collected test
        as a method, after filtering.

        :param name: name of the subclass
        """
        tests = type(name, (unittest.TestCase,), {})
        # Categories are sorted by their key (e.g. "nodeModel"), then tests
        # are sorted by name inside each category.
        for _category, items_map in sorted(self._filtered_test_items().items()):
            for test_name, item in sorted(items_map.items()):
                setattr(tests, test_name, item.func)
        return tests

    def run(
        self, verbose: int = 0, exc_cls: Optional[type] = AssertionError
    ) -> Tuple[
        List[Tuple[str, Callable]],
        List[Tuple[str, Callable, Any]],
        List[Tuple[str, Callable, Exception]],
    ]:
        """
        Runs all tests.

        :param verbose: verbosity, use :epkg:`tqdm`
        :param exc_cls: exception to raise when a test fails,
            if None, no exception is raised
        :return: list of run tests, list of skipped tests, list of failed tests;
            when *exc_cls* is None a failed test is listed both in the first
            and in the last list
        :raises RuntimeError: if no test method was collected
        """
        tests = self.tests()
        methods = [
            (att, getattr(tests, att)) for att in dir(tests) if att.startswith("test_")
        ]
        if not methods:
            msg = "\n".join(dir(tests))
            raise RuntimeError(f"No test was detected. Available tests are:\n{msg}")
        if verbose:
            from tqdm import tqdm

            loop = tqdm(methods)
        else:
            loop = methods
        ran = []
        skipped = []
        failed = []
        for i, (name, f) in enumerate(loop):
            if verbose:
                loop.set_description(f"{i+1}/{len(methods)}-{name}")
            try:
                # The functions are called unbound, with the generated
                # TestCase class passed as first argument.
                # NOTE(review): because unittest's runner is bypassed,
                # tests marked via xfail() are not treated as expected
                # failures here.
                f(tests)
            except SkipTest as es:
                skipped.append((name, f, es))
                continue
            except Exception as e:
                if exc_cls is not None:
                    raise exc_cls(f"Test {i}-{name!r} failed.") from e
                failed.append((name, f, e))
            ran.append((name, f))
        return ran, skipped, failed
class CReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep):
    """
    See :class:`onnx_extended.reference.CReferenceEvaluator` for an example.

    :param session: any runtime with the same interface as
        :class:`onnx.reference.ReferenceEvaluator`
    """

    def __init__(self, session: CReferenceEvaluator):
        self._session = session

    def run(self, inputs: List[numpy.ndarray], **kwargs) -> List[numpy.ndarray]:
        """
        Runs the session on *inputs*.

        :param inputs: a single array, a list or tuple of arrays, or a
            dictionary mapping input names to arrays
        :param kwargs: accepted for API compatibility, not used
        :return: whatever ``session.run(None, feeds)`` returns
        :raises TypeError: if *inputs* is of an unexpected type
        """
        if isinstance(inputs, numpy.ndarray):
            inputs = [inputs]
        elif isinstance(inputs, tuple):
            inputs = list(inputs)
        if isinstance(inputs, list):
            if len(inputs) == len(self._session.input_names):
                # One array per declared input: bind them positionally.
                feeds = dict(zip(self._session.input_names, inputs))
            else:
                feeds = self._match_inputs_by_shape(inputs)
        elif isinstance(inputs, dict):
            feeds = inputs
        else:
            raise TypeError(f"Unexpected input type {type(inputs)!r}.")
        return self._session.run(None, feeds)

    def _match_inputs_by_shape(self, inputs: List[numpy.ndarray]) -> dict:
        """
        Assigns *inputs*, in order, to the declared session inputs whose
        static shape equals the array shape. Declared inputs that do not
        match are left out of the returned feeds.
        """
        feeds = {}
        pos_inputs = 0
        for inp, tshape in zip(self._session.input_names, self._session.input_types):
            # Checked before indexing so that an empty or exhausted list of
            # arrays stops the matching instead of raising IndexError.
            if pos_inputs >= len(inputs):
                break
            # NOTE(review): dimensions without a fixed value presumably
            # report dim_value 0, so such inputs never match here — confirm.
            shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
            if shape == inputs[pos_inputs].shape:
                feeds[inp] = inputs[pos_inputs]
                pos_inputs += 1
        return feeds
class CReferenceEvaluatorBackend(onnx.backend.base.Backend):
    """
    See :class:`onnx_extended.reference.CReferenceEvaluator` for an example.
    """

    # Runtime class instantiated by create_inference_session; overridden by
    # subclasses created through CReferenceEvaluatorBackend[OtherRuntime].
    cls_inference = CReferenceEvaluator

    @classmethod
    def __class_getitem__(cls, cls_inference: type, name: Optional[str] = None) -> type:
        """
        Creates a new class inheriting from this one but with static
        attribute `cls_inference` equal to *cls_inference*.
        The goal is to make it easier to evaluate a runtime
        sharing the same API as the :class:`CReferenceEvaluator`
        on CPU.
        """
        if name is None:
            name = f"{cls.__name__}{cls_inference.__name__}"
        return type(name, (cls,), {"cls_inference": cls_inference})

    @classmethod
    def is_opset_supported(cls, model):
        """
        Tells which opsets are supported. Always returns ``(True, "")``.
        """
        return True, ""

    @classmethod
    def supports_device(cls, device: str) -> bool:
        """
        Tells if a specific device is supported (only CPU is).
        """
        d = Device(device)
        return d.type == DeviceType.CPU

    @classmethod
    def create_inference_session(
        cls, model: Union[str, bytes, ModelProto, NodeProto, FunctionProto]
    ):
        """
        Creates an instance of the class running a model.
        """
        return cls.cls_inference(model)

    @classmethod
    def prepare(
        cls, model: Any, device: str = "CPU", **kwargs: Any
    ) -> CReferenceEvaluatorBackendRep:
        """
        Wraps *model* into a :class:`CReferenceEvaluatorBackendRep`.

        :param model: an instance of ``cls.cls_inference``, or a ``str``,
            ``bytes`` or ``ModelProto`` used to build one
        :param device: forwarded unchanged; not used to select a device
        :param kwargs: forwarded unchanged to the recursive call
        :return: a backend representation ready to run
        :raises TypeError: if *model* has an unsupported type
        """
        if isinstance(model, cls.cls_inference):
            return CReferenceEvaluatorBackendRep(model)
        if isinstance(model, (str, bytes, ModelProto)):
            inf = cls.create_inference_session(model)
            return cls.prepare(inf, device, **kwargs)
        raise TypeError(f"Unexpected type {type(model)} for model.")

    @classmethod
    def run_model(
        cls,
        model,
        inputs: List[Any],
        device: Optional[str] = None,
        **kwargs: Any,
    ):
        """
        Called if the onnx proto is a `ModelProto`.
        """
        rep = cls.prepare(model, device or "cpu", **kwargs)
        return rep.run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        """
        Called if the onnx proto is a `NodeProto`.
        Not supported by this backend.

        :raises NotImplementedError: always
        """
        raise NotImplementedError("Unable to run the model node by node.")
def create_reference_backend(
    backend: Optional[type[Backend]] = None,
    path_to_test: Optional[str] = None,
    kind: Optional[Union[str, List[str]]] = None,
    test_kwargs: Optional[dict[str, Any]] = None,
) -> Runner:
    """
    Creates a :class:`Runner` collecting the backend tests stored in
    *path_to_test*.

    :param backend: backend class to test, :class:`CReferenceEvaluatorBackend`
        if None
    :param path_to_test: folder containing the tests
    :param kind: subfolder (or list of subfolders) to test, all of them if None
    :param test_kwargs: additional test parameters forwarded to the runner
    :return: a :class:`Runner`
    """
    return Runner(
        backend or CReferenceEvaluatorBackend,
        path_to_test=path_to_test,
        kind=kind,
        test_kwargs=test_kwargs,
    )