# Source code for onnx_extended.tools.onnx_tools

import os
from typing import Dict, Generator, Iterator, Optional, Set, Tuple, Union
import onnx

# Reverse lookup table: maps the integer value of an ``onnx.TensorProto``
# constant back to its attribute name (e.g. ``1 -> "FLOAT"``).
# NOTE(review): several distinct enums on TensorProto share integer values,
# so later names overwrite earlier ones on collision — presumably acceptable
# for display purposes; confirm before relying on exact names.
_rev_type: Dict[int, str] = {
    getattr(onnx.TensorProto, k): k
    for k in dir(onnx.TensorProto)
    if isinstance(getattr(onnx.TensorProto, k), int)
}


def load_model(
    model: Union[str, onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto],
    external: bool = True,
    base_dir: Optional[str] = None,
) -> Union[onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto]:
    """
    Loads a model or returns the only argument if the type
    is already a ModelProto.

    :param model: proto file
    :param external: loads the external data as well
    :param base_dir: needed if external is True and the model
        has external weights
    :return: ModelProto
    """
    if isinstance(model, (onnx.GraphProto, onnx.FunctionProto)):
        # Already a proto, returned unchanged.
        return model
    if isinstance(model, onnx.ModelProto):
        # Optionally pull the external weights into the proto.
        if external and base_dir is not None:
            if not os.path.exists(base_dir):
                raise FileNotFoundError(f"Unable to find folder {base_dir!r}.")
            onnx.load_external_data_for_model(model, base_dir)
        return model
    # ``model`` is a file path at this point.
    if not os.path.exists(model):
        raise FileNotFoundError(f"Unable to find model {model!r}.")
    with open(model, "rb") as stream:
        return onnx.load(stream, load_external_data=external)
def load_external(
    model: onnx.ModelProto, base_dir: str, names: Optional[Set[str]] = None
):
    """
    Loads external data into memory.

    :param model: the model loaded with :func:`load_model`
    :param base_dir: directory when the data can be found
    :param names: subsets of names to load or None for all
    """
    from onnx.external_data_helper import (
        _get_all_tensors,
        load_external_data_for_tensor,
        uses_external_data,
    )

    for tensor in _get_all_tensors(model):
        if names is not None and tensor.name not in names:
            # Not part of the requested subset.
            continue
        if not uses_external_data(tensor):
            continue
        load_external_data_for_tensor(tensor, base_dir)
        # raw_data is now filled in: mark the tensor as internal
        # and drop the stale external_data entries.
        tensor.data_location = onnx.TensorProto.DEFAULT
        del tensor.external_data[:]
def enumerate_model_tensors(
    model: onnx.ModelProto,
) -> Iterator[Tuple[onnx.TensorProto, bool]]:
    """
    Enumerates all tensors in a model.

    :param model: model to process
    :return: iterator on a couple (TensorProto, bool),
        the boolean indicates if the data is external
    """
    from onnx.external_data_helper import _get_all_tensors, uses_external_data

    yield from (
        (tensor, uses_external_data(tensor)) for tensor in _get_all_tensors(model)
    )
def save_model(
    proto: onnx.ModelProto,
    filename: str,
    external: bool = False,
    convert_attribute: bool = True,
    size_threshold: int = 1024,
    all_tensors_to_one_file: bool = True,
):
    """
    Saves a model into an onnx file.

    :param proto: ModelProto
    :param filename: where to save it
    :param external: saves weights as external data
    :param convert_attribute: converts attributes as well
    :param size_threshold: every weight above that threshold is saved
        as external
    :param all_tensors_to_one_file: saves all tensors in one unique file
    """
    if not external:
        # Simple case: everything stays inside a single onnx file.
        onnx.save_model(proto, filename)
        return

    folder, name = os.path.split(filename)
    # Move big weights out of the proto into a side file ``<name>.data``.
    onnx.convert_model_to_external_data(
        proto,
        all_tensors_to_one_file=all_tensors_to_one_file,
        location=name + ".data",
        convert_attribute=convert_attribute,
        size_threshold=size_threshold,
    )
    proto = onnx.write_external_data_tensors(proto, folder)
    with open(filename, "wb") as stream:
        stream.write(proto.SerializeToString())
def _info_type(
    typ: Union[onnx.TensorProto, onnx.TypeProto, onnx.SparseTensorProto]
) -> Dict[str, str]:
    """
    Summarizes a type into a small dictionary with keys
    ``type``, ``elem_type``, ``shape`` (or ``kind`` as a fallback).
    """
    if typ is None:
        return {}
    if isinstance(typ, (onnx.TensorProto, onnx.SparseTensorProto)):
        # Concrete tensor: dims are plain integers.
        return {
            "type": "tensor",
            "elem_type": _rev_type[typ.data_type],
            "shape": "x".join(str(d) for d in typ.dims),
        }
    # NOTE(review): relies on the truthiness of the ``tensor_type``
    # protobuf submessage, kept exactly as the original code — confirm
    # non-tensor TypeProto values are expected to fall through here.
    if typ.tensor_type:
        res = {
            "type": "tensor",
            "elem_type": _rev_type[typ.tensor_type.elem_type],
        }
        dims = []
        for d in typ.tensor_type.shape.dim:
            # A zero dim_value means the dimension is symbolic or unknown.
            dims.append(str(d.dim_value) if d.dim_value else (d.dim_param or "?"))
        res["shape"] = "x".join(dims)
        return res
    return {"kind": str(type(typ))}
def enumerate_onnx_node_types(
    model: Union[str, onnx.ModelProto, onnx.GraphProto],
    level: int = 0,
    shapes: Optional[Dict[str, onnx.TypeProto]] = None,
    external: bool = True,
) -> Generator[Dict[str, Union[str, float]], None, None]:
    """
    Looks into types for every node in a model.

    :param model: a string or a proto
    :param level: level (recursivity level)
    :param shapes: known shapes, returned by
        :func:`onnx.shape_inference.infer_shapes`
    :param external: loads the external data if the model is loaded
    :return: an iterator on dictionaries which can be turned into a dataframe
    """
    proto = load_model(model, external=external)
    if shapes is None and isinstance(proto, onnx.ModelProto):
        # Run shape inference once at the top level; the resulting mapping
        # is reused by the recursive calls on subgraphs.
        p2 = onnx.shape_inference.infer_shapes(proto)
        shapes = {value.name: value.type for value in p2.graph.value_info}
        for o in proto.graph.output:
            if o.name not in shapes:
                shapes[o.name] = o.type

    if isinstance(proto, onnx.ModelProto):
        if shapes is None:
            raise RuntimeError("shape inference has failed.")
        yield from enumerate_onnx_node_types(proto.graph, level=level, shapes=shapes)
    elif isinstance(proto, onnx.FunctionProto):
        # Fix: the original tested ``model`` here, not the loaded ``proto``,
        # while the message reports ``type(proto)``. FunctionProto inputs are
        # bare strings, so the GraphProto branch below cannot handle them.
        raise NotImplementedError(f"Not implemented for type {type(proto)}.")
    else:
        if shapes is None:
            # Robustness fix: a GraphProto passed directly with no ``shapes``
            # previously crashed on ``shapes.get`` below; fall back to an
            # empty mapping (types then show up as unknown).
            shapes = {}
        for inp in proto.input:
            obs = dict(level=level, name=inp.name, kind="input")
            obs.update(_info_type(inp.type))
            yield obs
        for init in proto.initializer:
            obs = dict(level=level, name=init.name, kind="initializer")
            obs.update(_info_type(init))
            yield obs
        for init in proto.sparse_initializer:
            obs = dict(level=level, name=init.name, kind="sparse_initializer")
            obs.update(_info_type(init))
            yield obs
        for node in proto.node:
            obs = dict(
                level=level,
                name=node.name,
                kind="Op",
                domain=node.domain,
                type=node.op_type,
                inputs=",".join(node.input),
                outputs=",".join(node.output),
                input_types=",".join(
                    _info_type(shapes.get(i, None)).get("elem_type", "")
                    for i in node.input
                ),
                output_types=",".join(
                    _info_type(shapes.get(i, None)).get("elem_type", "")
                    for i in node.output
                ),
            )
            yield obs

            for att in node.attribute:
                # Recurse into subgraphs (If, Loop, Scan...).
                if att.type == onnx.AttributeProto.GRAPH:
                    obs = dict(name=att.name, kind="attribute", level=level + 1)
                    yield obs
                    yield from enumerate_onnx_node_types(
                        att.g, level=level + 1, shapes=shapes
                    )

            for out in node.output:
                obs = dict(name=out, kind="result", level=level)
                obs.update(_info_type(shapes.get(out, None)))
                yield obs

        for out in proto.output:
            obs = dict(level=level, name=out.name, kind="output")
            obs.update(_info_type(out.type))
            yield obs