Source code for onnx_extended.tools.onnx_io

import base64
import os
import textwrap
from typing import Optional, Set, Union
import onnx


[docs] def load_model( model: Union[str, onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto], external: bool = True, base_dir: Optional[str] = None, ) -> Union[onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto]: """ Loads a model or returns the only argument if the type is already a ModelProto. :param model: proto file :param external: loads the external data as well :param base_dir: needed if external is True and the model has external weights :return: ModelProto """ if isinstance(model, onnx.ModelProto): if base_dir is not None and external: if not os.path.exists(base_dir): raise FileNotFoundError(f"Unable to find folder {base_dir!r}.") onnx.load_external_data_for_model(model, base_dir) return model if isinstance(model, (onnx.GraphProto, onnx.FunctionProto)): return model if not os.path.exists(model): raise FileNotFoundError(f"Unable to find model {model!r}.") with open(model, "rb") as f: return onnx.load(f, load_external_data=external)
[docs] def load_external( model: onnx.ModelProto, base_dir: str, names: Optional[Set[str]] = None ): """ Loads external data into memory. :param model: the model loaded with :func:`load_model` :param base_dir: directory when the data can be found :param names: subsets of names to load or None for all """ from onnx.external_data_helper import ( _get_all_tensors, uses_external_data, load_external_data_for_tensor, ) for tensor in _get_all_tensors(model): if names is not None and tensor.name not in names: continue if uses_external_data(tensor): load_external_data_for_tensor(tensor, base_dir) # After loading raw_data from external_data, change the state of tensors tensor.data_location = onnx.TensorProto.DEFAULT # and remove external data del tensor.external_data[:]
[docs] def save_model( proto: onnx.ModelProto, filename: str, external: bool = False, convert_attribute: bool = True, size_threshold: int = 1024, all_tensors_to_one_file: bool = True, ): """ Saves a model into an onnx file. :param proto: ModelProto :param filename: where to save it :param external: saves weights as external data :param convert_attribute: converts attributes as well :param size_threshold: every weight above that threshold is saved as external :param all_tensors_to_one_file: saves all tensors in one unique file """ if not external: onnx.save_model(proto, filename) return dirname, shortname = os.path.split(filename) onnx.convert_model_to_external_data( proto, all_tensors_to_one_file=all_tensors_to_one_file, location=shortname + ".data", convert_attribute=convert_attribute, size_threshold=size_threshold, ) proto = onnx.write_external_data_tensors(proto, dirname) with open(filename, "wb") as f: f.write(proto.SerializeToString())
[docs] def onnx2string(proto: onnx.ModelProto, as_code: bool = False) -> str: """ Takes a model and returns a string pluggagle in a python script. It uses module `base64`. :param proto: model proto :param as_code: if true, returns the model as a piece of code :return: string """ text = proto.SerializeToString() content = base64.b64encode(text).decode("ascii") if not as_code: return content lines = [] while content: if len(content) > 64: lines.append(content[:64]) content = content[64:] else: lines.append(content) break slines = "\n".join([f' "{li}"' for li in lines]) template = textwrap.dedent( """ from onnx_extended.tools.onnx_io import string2onnx text = ( {model} ) model = string2onnx(text) """ ) return template.format(model=slines)
[docs] def string2onnx(text: str) -> onnx.ModelProto: """ Restores a model saved with :func:`onnx2string <onnx_extended.tools.onnx_io.onnx2string>`. :param text: string produced by :func:`onnx2string <onnx_extended.tools.onnx_io.onnx2string>` :return: ModelProto """ b = base64.b64decode(text.encode("ascii")) model = onnx.ModelProto() model.ParseFromString(b) return model