import os
from typing import Dict, Generator, Iterator, Optional, Set, Tuple, Union
import onnx
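# Descriptive note (added comment): reverse mapping from TensorProto integer
# constants to their names, e.g. 1 -> "FLOAT"; used below to display readable
# element types.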
_rev_type: Dict[int, str] = {
getattr(onnx.TensorProto, k): k
for k in dir(onnx.TensorProto)
if isinstance(getattr(onnx.TensorProto, k), int)
}
def load_model(
model: Union[str, onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto],
external: bool = True,
base_dir: Optional[str] = None,
) -> Union[onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto]:
"""
Loads a model or returns the only argument if the type
is already a ModelProto.
:param model: proto file
:param external: loads the external data as well
:param base_dir: needed if external is True and
the model has external weights
:return: ModelProto
"""
if isinstance(model, onnx.ModelProto):
if base_dir is not None and external:
if not os.path.exists(base_dir):
raise FileNotFoundError(f"Unable to find folder {base_dir!r}.")
onnx.load_external_data_for_model(model, base_dir)
return model
if isinstance(model, (onnx.GraphProto, onnx.FunctionProto)):
return model
if not os.path.exists(model):
raise FileNotFoundError(f"Unable to find model {model!r}.")
with open(model, "rb") as f:
return onnx.load(f, load_external_data=external)
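# Usage sketch (hypothetical file name, kept as a comment so that importing
# this module stays free of side effects):
#
#     proto = load_model("model.onnx", external=True)
#     graph = load_model(proto.graph)  # a GraphProto is returned unchanged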
def load_external(
model: onnx.ModelProto, base_dir: str, names: Optional[Set[str]] = None
):
"""
Loads external data into memory.
:param model: the model loaded with :func:`load_model`
:param base_dir: directory where the external data can be found
:param names: subset of tensor names to load, or None to load them all
"""
from onnx.external_data_helper import (
_get_all_tensors,
uses_external_data,
load_external_data_for_tensor,
)
for tensor in _get_all_tensors(model):
if names is not None and tensor.name not in names:
continue
if uses_external_data(tensor):
load_external_data_for_tensor(tensor, base_dir)
# After loading raw_data from external_data, change the state of tensors
tensor.data_location = onnx.TensorProto.DEFAULT
# and remove external data
del tensor.external_data[:]
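# Usage sketch (hypothetical file and tensor names): reload the weights of a
# model whose external data was skipped at load time.
#
#     proto = load_model("model.onnx", external=False)
#     load_external(proto, base_dir=".", names={"conv1.weight"})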
def enumerate_model_tensors(
model: onnx.ModelProto,
) -> Iterator[Tuple[onnx.TensorProto, bool]]:
"""
Enumerates all tensors in a model.
:param model: model to process
:return: iterator over pairs (TensorProto, bool);
the boolean indicates whether the tensor data is stored externally
"""
from onnx.external_data_helper import (
_get_all_tensors,
uses_external_data,
)
for tensor in _get_all_tensors(model):
yield tensor, uses_external_data(tensor)
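# Usage sketch: collect the names of the tensors stored outside the proto.
#
#     external_names = [
#         t.name for t, ext in enumerate_model_tensors(proto) if ext
#     ]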
def save_model(
proto: onnx.ModelProto,
filename: str,
external: bool = False,
convert_attribute: bool = True,
size_threshold: int = 1024,
all_tensors_to_one_file: bool = True,
):
"""
Saves a model into an onnx file.
:param proto: ModelProto
:param filename: where to save it
:param external: saves weights as external data
:param convert_attribute: converts tensors stored in attributes as well
:param size_threshold: every tensor larger than this threshold is saved as external data
:param all_tensors_to_one_file: saves all external tensors in a single file
"""
if not external:
onnx.save_model(proto, filename)
return
dirname, shortname = os.path.split(filename)
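# Tensors bigger than size_threshold are moved out of the proto and written
# next to the model, in a single file named "<shortname>.data" when
# all_tensors_to_one_file is True.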
onnx.convert_model_to_external_data(
proto,
all_tensors_to_one_file=all_tensors_to_one_file,
location=shortname + ".data",
convert_attribute=convert_attribute,
size_threshold=size_threshold,
)
proto = onnx.write_external_data_tensors(proto, dirname)
with open(filename, "wb") as f:
f.write(proto.SerializeToString())
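# Usage sketch (hypothetical file name): weights above the size threshold end
# up in "model.onnx.data" next to the model file.
#
#     save_model(proto, "model.onnx", external=True, size_threshold=1024)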
def _info_type(
typ: Optional[Union[onnx.TensorProto, onnx.TypeProto, onnx.SparseTensorProto]]
) -> Dict[str, str]:
if typ is None:
return {}
if isinstance(typ, (onnx.TensorProto, onnx.SparseTensorProto)):
shape = [str(i) for i in typ.dims]
return dict(
type="tensor", elem_type=_rev_type[typ.data_type], shape="x".join(shape)
)
if typ.tensor_type:
ret = dict(type="tensor", elem_type=_rev_type[typ.tensor_type.elem_type])
shape = []
for d in typ.tensor_type.shape.dim:
if d.dim_value:
shape.append(str(d.dim_value))
else:
shape.append(d.dim_param or "?")
ret["shape"] = "x".join(shape)
return ret
return dict(kind=str(type(typ)))
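# _info_type returns for instance
# {"type": "tensor", "elem_type": "FLOAT", "shape": "1x3x224x224"}
# for a float tensor of shape (1, 3, 224, 224), and {} when the type is unknown.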
def enumerate_onnx_node_types(
model: Union[str, onnx.ModelProto, onnx.GraphProto],
level: int = 0,
shapes: Optional[Dict[str, onnx.TypeProto]] = None,
external: bool = True,
) -> Generator[Dict[str, Union[str, float]], None, None]:
"""
Looks into types for every node in a model.
:param model: a string or a proto
:param level: level (recursivity level)
:param shapes: known shapes,
returned by :func:onnx.shape_inference.infer_shapes`
:param external: loads the external data if the model is loaded
:return: a list of dictionary which can be turned into a dataframe.
"""
proto = load_model(model, external=external)
if shapes is None and isinstance(proto, onnx.ModelProto):
p2 = onnx.shape_inference.infer_shapes(proto)
values = p2.graph.value_info
shapes = {}
for value in values:
shapes[value.name] = value.type
for o in proto.graph.output:
if o.name not in shapes:
shapes[o.name] = o.type
if isinstance(proto, onnx.ModelProto):
if shapes is None:
raise RuntimeError("shape inference has failed.")
for item in enumerate_onnx_node_types(proto.graph, level=level, shapes=shapes):
yield item
elif isinstance(proto, onnx.FunctionProto):
raise NotImplementedError(f"Not implemented for type {type(proto)}.")
else:
for inp in proto.input:
obs = dict(level=level, name=inp.name, kind="input")
obs.update(_info_type(inp.type))
yield obs
for init in proto.initializer:
obs = dict(level=level, name=init.name, kind="initializer")
obs.update(_info_type(init))
yield obs
for init in proto.sparse_initializer:
obs = dict(level=level, name=init.name, kind="sparse_initializer")
obs.update(_info_type(init))
yield obs
for node in proto.node:
obs = dict(
level=level,
name=node.name,
kind="Op",
domain=node.domain,
type=node.op_type,
inputs=",".join(node.input),
outputs=",".join(node.output),
input_types=",".join(
_info_type(shapes.get(i, None)).get("elem_type", "")
for i in node.input
),
output_types=",".join(
_info_type(shapes.get(i, None)).get("elem_type", "")
for i in node.output
),
)
yield obs
for att in node.attribute:
if att.type == onnx.AttributeProto.GRAPH:
obs = dict(name=att.name, kind="attribute", level=level + 1)
yield obs
for item in enumerate_onnx_node_types(
att.g, level=level + 1, shapes=shapes
):
yield item
for out in node.output:
obs = dict(name=out, kind="result", level=level)
obs.update(_info_type(shapes.get(out, None)))
yield obs
for out in proto.output:
obs = dict(level=level, name=out.name, kind="output")
obs.update(_info_type(out.type))
yield obs
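# Usage sketch (hypothetical file name; assumes pandas is installed, which this
# module does not require):
#
#     import pandas
#     df = pandas.DataFrame(enumerate_onnx_node_types("model.onnx"))
#     print(df)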