# Source code for onnx_extended.onnx2.io_helper

import os
from pathlib import Path
from typing import Optional
from . import ModelProto, ParseOptions, SerializeOptions


def save(
    proto: ModelProto,
    f: str | Path,
    format: str = "protobuf",
    *,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """
    Saves the ModelProto to the specified path and optionally stores
    tensors with raw data as external data.

    :param proto: an in-memory ModelProto
    :param f: a string containing a file name or a :class:`pathlib.Path`;
        file-like objects are not supported
    :param format: the serialization format, only ``"protobuf"`` is supported
    :param save_as_external_data: if True, saves tensors to an external file
    :param all_tensors_to_one_file: if external data is enabled, saves all
        tensors into one file instead of one file per tensor; only True is
        currently supported
    :param location: name of the external data file. Providing a non-empty
        value also enables external data even if `save_as_external_data`
        is False. If `save_as_external_data` is True and `location` is None,
        ``"<f>.data"`` (a file named after the model) is used.
    :param size_threshold: effective only when external data is enabled;
        forwarded as ``SerializeOptions.raw_data_threshold``. Only tensors
        whose data is >= size_threshold are expected to be stored externally;
        use 0 to convert every tensor with raw data.
    :param convert_attribute: currently ignored by this implementation
    :raises AssertionError: if an argument has an unexpected type or an
        unsupported value
    """
    assert isinstance(proto, ModelProto), f"Unexpected type {type(proto)} for proto."
    assert isinstance(f, (str, Path)), f"Unexpected type {type(f)} for f."
    assert format == "protobuf", f"Unsupported format={format!r}"
    assert (
        all_tensors_to_one_file
    ), f"all_tensors_to_one_file={all_tensors_to_one_file} is not implemented"
    if save_as_external_data or location:
        if location is None:
            # Without a default, str(None) would write the weights into a
            # file literally named "None"; name it after the model instead.
            location = f"{f}.data"
        opts = SerializeOptions()
        opts.raw_data_threshold = size_threshold
        proto.SerializeToFile(str(f), opts, str(location))
    else:
        proto.SerializeToFile(str(f))
def load(
    f: str | Path | bytes,
    skip_raw_data: bool = False,
    raw_data_threshold: int = 1024,
    load_external_data: Optional[bool] = None,
    parallel: bool = False,
    num_threads: int = -1,
    location: str = "",
) -> ModelProto:
    """
    Loads a serialized ModelProto into memory.

    :param f: path to a file with the ``.onnx`` extension (:class:`str` or
        :class:`pathlib.Path`), or the serialized model as :class:`bytes`
    :param skip_raw_data: skips the raw data of every tensor, this can be used
        to load only the architecture of the model even if the model is stored
        in one unique file
    :param raw_data_threshold: if `skip_raw_data` is True, still keeps the
        tensors smaller than this size (in bytes)
    :param load_external_data: whether to load the external data. When None
        (default), it is enabled if and only if `location` is specified.
        If True, `location` must be specified; if `location` is specified,
        it must not be False.
    :param parallel: parallelize the loading of the tensors
    :param num_threads: number of threads to use, -1 means the number of cores
    :param location: location of the external weights (can be different from
        the value stored in the main model). It is only forwarded when `f` is
        a path; it is not passed to the parser when `f` is bytes.
    :return: loaded in-memory ModelProto
    :raises AssertionError: if `f` has an unexpected type or extension, or if
        `location` and `load_external_data` are inconsistent
    """
    assert isinstance(f, (str, bytes, Path)), f"Unexpected type {type(f)} for f."
    if load_external_data is None:
        load_external_data = bool(location)
    assert (
        not load_external_data or location
    ), f"'location' must be specified if load_external_data={load_external_data}"
    assert (
        not location or load_external_data
    ), f"'load_external_data' must be True if location={location!r}"
    if isinstance(f, Path):
        f = str(f)
    assert (
        not isinstance(f, str) or os.path.splitext(f)[-1] == ".onnx"
    ), f"File name must have the extension .onnx to be loaded but f={f!r}"

    model = ModelProto()
    # ParseOptions are only passed when a non-default behaviour is requested,
    # otherwise the parser is called without them.
    parse_args = ()
    if skip_raw_data or parallel:
        opts = ParseOptions()
        opts.skip_raw_data = skip_raw_data
        opts.raw_data_threshold = raw_data_threshold
        opts.parallel = parallel
        opts.num_threads = num_threads
        parse_args = (opts,)

    if isinstance(f, bytes):
        model.ParseFromString(f, *parse_args)
    elif location:
        model.ParseFromFile(f, *parse_args, external_data_file=location)
    else:
        model.ParseFromFile(f, *parse_args)
    return model