Source code for onnx_array_api.ort.ort_optimizers

from typing import Union, Optional
from onnx import ModelProto, load
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.capi._pybind_state import GraphOptimizationLevel
from ..cache import get_cache_file


def ort_optimized_model(
    onx: Union[str, ModelProto],
    level: str = "ORT_ENABLE_ALL",
    output: Optional[str] = None,
) -> Union[str, ModelProto]:
    """
    Returns the optimized model used by onnxruntime before
    computing the inference.

    :param onx: ModelProto
    :param level: optimization level, `'ORT_ENABLE_BASIC'`,
        `'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
    :param output: output file if the proposed cache is not wanted
    :return: optimized model
    """
    glevel = getattr(GraphOptimizationLevel, level, None)
    if glevel is None:
        raise ValueError(
            f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}."
        )

    if output is not None:
        cache = output
    else:
        cache = get_cache_file("ort_optimized_model.onnx", remove=True)
    so = SessionOptions()
    so.graph_optimization_level = glevel
    so.optimized_model_filepath = str(cache)
    # Creating the session triggers the optimization and writes
    # the optimized model to `optimized_model_filepath`.
    InferenceSession(
        onx if isinstance(onx, str) else onx.SerializeToString(),
        so,
        providers=["CPUExecutionProvider"],
    )
    if output is None and not cache.exists():
        raise RuntimeError(f"The optimized model {str(cache)!r} was not found.")
    if output is not None:
        return output
    if isinstance(onx, str):
        return str(cache)
    # No explicit output was requested: load the optimized model
    # from the cache file, then remove the file.
    opt_onx = load(str(cache))
    cache.unlink()
    return opt_onx
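
A minimal usage sketch, not part of the module: it assumes an ONNX file ``model.onnx`` exists on disk (the file name is purely illustrative). Because the input is a path, the function returns the path of the optimized model written by onnxruntime rather than a ModelProto.

    from onnx_array_api.ort.ort_optimizers import ort_optimized_model

    # "model.onnx" is a hypothetical path used only for illustration.
    optimized_path = ort_optimized_model("model.onnx", level="ORT_ENABLE_BASIC")
    print(optimized_path)  # path of the optimized model written by onnxruntime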