import logging
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import numpy as np
from onnx import (
AttributeProto,
FunctionProto,
GraphProto,
ModelProto,
NodeProto,
SparseTensorProto,
TensorProto,
TypeProto,
ValueInfoProto,
)
from onnx.compose import merge_models
from onnx.helper import (
make_attribute,
make_node,
make_graph,
make_function,
make_model,
make_opsetid,
make_tensor,
make_tensor_value_info,
np_dtype_to_tensor_dtype,
set_model_props,
)
from onnx.numpy_helper import from_array, to_array
from onnx.shape_inference import infer_shapes as onnx_infer_shapes
from onnx.version_converter import convert_version
from .onnx_io import load_model
# Module-level logger for the whole package.
logger = logging.getLogger("onnx-extended")

# Reverse mapping from TensorProto element-type codes to their attribute
# names (e.g. TensorProto.FLOAT -> "FLOAT"); used by _info_type to render
# readable element types.
_rev_type: Dict[int, str] = {
    getattr(TensorProto, k): k
    for k in dir(TensorProto)
    if isinstance(getattr(TensorProto, k), int)
}
def _make_att_graph(name: str, new_body: GraphProto) -> AttributeProto:
    """Wraps the graph *new_body* into a GRAPH attribute called *name*."""
    graph_attribute = AttributeProto()
    graph_attribute.type = AttributeProto.GRAPH
    graph_attribute.name = name
    graph_attribute.g.CopyFrom(new_body)
    return graph_attribute
def _make_node(
    op_type: str,
    inputs: List[str],
    outputs: List[str],
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    domain: str = "",
    attributes: Optional[Dict[str, Any]] = None,
) -> NodeProto:
    """
    Builds a NodeProto.

    :param op_type: (string): The name of the operator to construct
    :param inputs: list of input names
    :param outputs: list of output names
    :param name: optional unique identifier for NodeProto
    :param doc_string: optional documentation string for NodeProto
    :param domain: optional domain for NodeProto.
        If it's None, we will just use default domain (which is empty)
    :param attributes: the attributes of the node, either a dictionary whose
        values are acceptable for `make_attribute`, or an iterable of already
        built attribute protos
    :return: node
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if isinstance(attributes, dict):
        if attributes:
            # Sorted for a deterministic attribute order.
            node.attribute.extend(
                make_attribute(k, v) for k, v in sorted(attributes.items())
            )
    elif attributes:
        # Already built AttributeProto objects.
        node.attribute.extend(attributes)
    return node
def _apply_optimisation_on_graph(
    fct: Callable,
    onnx_model: Union[ModelProto, GraphProto, FunctionProto],
    recursive: bool = True,
    debug_info: Optional[List[str]] = None,
    **kwargs: Dict[str, Any],
) -> Union[ModelProto, GraphProto, FunctionProto]:
    """
    Applies an optimisation function *fct* on a graph
    and not on the model.

    :param fct: function to optimize like :func:`onnx_remove_node_unused`
    :param onnx_model: onnx model
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private)
    :param kwargs: additional parameters
    :return: new onnx model
    :raises TypeError: if *onnx_model* is not a ModelProto
    """
    if hasattr(onnx_model, "graph"):
        if debug_info is None:
            debug_info = []
        # Optimises the main graph then every local function.
        graph = fct(onnx_model.graph, debug_info=debug_info + ["GRAPH"], **kwargs)
        functions = [
            fct(f, debug_info=debug_info + [f"FUNCTION {f.name}"], **kwargs)
            for f in onnx_model.functions
        ]
        if hasattr(onnx_model, "value_info"):
            graph.value_info.extend(onnx_model.value_info)
        # Rebuilds a model carrying over the original metadata.
        new_model = make_model(
            graph,
            opset_imports=[
                make_opsetid(d.domain, d.version) for d in onnx_model.opset_import
            ],
            functions=functions,
        )
        new_model.ir_version = onnx_model.ir_version
        new_model.producer_name = onnx_model.producer_name
        new_model.producer_version = onnx_model.producer_version
        new_model.domain = onnx_model.domain
        new_model.model_version = onnx_model.model_version
        new_model.doc_string = onnx_model.doc_string
        return new_model
    # fixed: message previously read "anod not not on"
    raise TypeError(
        f"This function only works on 'ModelProto' and not on {type(onnx_model)}."
    )
def _apply_remove_node_fct_node(
    fct: Callable, node: NodeProto, recursive: bool, debug_info: List[str]
) -> NodeProto:
    """
    Applies an optimizing function on the subgraphs of a node
    (attributes ``body``, ``then_branch``, ``else_branch``).

    :param fct: optimisation function applied to every subgraph
    :param node: onnx node
    :param recursive: does it in subgraphs as well
    :param debug_info: debug information (private), list of names of the
        enclosing graphs/attributes
    :return: new node, or the original node when no subgraph was modified
    """
    if not hasattr(node, "attribute"):
        return node
    modified = 0
    new_atts = []
    for att in node.attribute:
        if att.name in ("body", "then_branch", "else_branch"):
            # Subgraph attribute: optimise the subgraph and rebuild the
            # attribute around the optimised body.
            new_body = fct(
                att.g, recursive=recursive, debug_info=debug_info + [att.name]
            )
            new_atts.append(_make_att_graph(att.name, new_body))
            modified += 1
        else:
            new_atts.append(att)
    if modified > 0:
        # NOTE(review): node.domain is not forwarded to the rebuilt node —
        # confirm this is intended for nodes outside the default domain.
        new_node = _make_node(
            node.op_type, node.input, node.output, name=node.name, attributes=new_atts
        )
        return new_node
    return node
def _process_node(
    node: NodeProto,
    data: Dict,
    edges: Dict,
    paths: Dict,
    prefix: str = "",
    sep: str = ":X:",
    path: Optional[List[str]] = None,
):
    """
    Walks a node and its subgraphs to fill dependency structures.

    Keys are pairs ``(name, kind)`` where *kind* is 0 for a result name
    and 1 for a node name.

    :param node: node to process
    :param data: maps ``(name, kind)`` to the node producing or consuming it
    :param edges: maps ``((n1, k1), (n2, k2))`` to the node creating the
        dependency from the first key to the second
    :param paths: maps ``(name, kind)`` to the path of enclosing node names
    :param prefix: prefix prepended to node names inside subgraphs
    :param sep: separator between a parent node name and subgraph node names
    :param path: list of enclosing node names (None at the top level)
    """
    node_name = prefix + node.name
    data[node_name, 1] = node
    path = [] if path is None else path.copy()
    paths[node_name, 1] = path
    path = path.copy()
    path.append(node_name)
    for inp in node.input:
        data[inp, 0] = node
        edges[(inp, 0), (node_name, 1)] = node
        paths[inp, 0] = path
        if sep in node_name:
            # We need to link an input to the parent node
            # if the node is part of subgraph.
            # path_r = paths[inp, 0]
            if len(path) <= 1:
                raise RuntimeError(
                    f"Unexpected path {path!r}, this may happen "
                    f"if sep={sep!r} is already used in the original model."
                )
            edges[(inp, 0), (path[-2], 1)] = node
    for out in node.output:
        data[out, 0] = node
        # NOTE(review): outputs store a plain node name (str) in *paths*
        # while inputs store a list — confirm this asymmetry is intended.
        paths[out, 0] = node_name
        edges[(node_name, 1), (out, 0)] = node
    if len(node.attribute) > 0:
        for att in node.attribute:
            if not hasattr(att, "g"):
                continue
            if not isinstance(att.g, GraphProto):
                continue
            # Recurses into subgraphs, prefixing names with the parent name.
            for no in att.g.node:
                _process_node(no, data, edges, paths, prefix=node_name + sep, path=path)
def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    :param onnx_model: onnx model, graph or function
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private)
    :param options: unused
    :return: new onnx model, graph or function
    """
    # fixed: the def line carried a "[docs]" Sphinx-HTML artifact which
    # made the file invalid Python
    if debug_info is None:
        debug_info = [str(type(onnx_model)).rsplit(".", maxsplit=1)[-1].strip("'>")]
    else:
        debug_info = debug_info + [
            str(type(onnx_model)).rsplit(".", maxsplit=1)[-1].strip("'>")
        ]
    if hasattr(onnx_model, "graph"):
        # A ModelProto: delegate to the graph/function level.
        return _apply_optimisation_on_graph(
            onnx_remove_node_unused,
            onnx_model,
            recursive=recursive,
            debug_info=debug_info,
            **options,
        )

    graph = onnx_model
    logger.debug("onnx_remove_node_unused:begin with %d nodes.", len(graph.node))
    is_function = isinstance(graph, FunctionProto)
    data = {}
    valid = {}
    edges = {}
    paths = {}

    if not is_function:
        for init in graph.initializer:
            data[init.name, 0] = init

    for node in graph.node:
        _process_node(node, data, edges, paths)

    # Marks outputs as useful, then propagates the flag backwards along
    # the edges until a fixed point is reached.
    for out in graph.output:
        valid[out if is_function else out.name, 0] = True

    modif = 1
    while modif > 0:
        modif = 0
        for e1, e2 in edges:  # pylint: disable=E1141
            if valid.get(e2, False) and not valid.get(e1, False):
                valid[e1] = True
                modif += 1

    new_nodes = [n for n in graph.node if (n.name, 1) in valid]
    if not is_function:
        new_inits = [n for n in graph.initializer if (n.name, 0) in valid]

    if recursive:
        # Handles subgraphs.
        for i in range(len(new_nodes)):
            node = new_nodes[i]
            if node is None or not (node.attribute):
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_unused,
                node,
                recursive=True,
                debug_info=debug_info + [node.name],
            )

    # Finally create the new graph.
    nodes = list(filter(lambda n: n is not None, new_nodes))
    if is_function:
        logger.debug("onnx_remove_node_unused:end function with %d nodes.", len(nodes))
        return make_function(
            onnx_model.domain,
            onnx_model.name,
            onnx_model.input,
            onnx_model.output,
            nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string,
        )
    graph = make_graph(
        nodes, onnx_model.name, onnx_model.input, onnx_model.output, new_inits
    )
    graph.value_info.extend(onnx_model.value_info)
    logger.debug("onnx_remove_node_unused:end graph with %d nodes.", len(nodes))
    return graph
def _guess_proto_dtype(dtype) -> int:
    """Converts a numpy dtype into the matching TensorProto element type."""
    return np_dtype_to_tensor_dtype(dtype)
def get_tensor_shape(
    obj: Union[ValueInfoProto, TypeProto, TensorProto]
) -> Optional[List[Union[int, str, None]]]:
    """
    Returns the shape if that makes sense for this object.
    """
    if isinstance(obj, ValueInfoProto):
        # Unwraps the type and retries.
        return get_tensor_shape(obj.type)
    if not isinstance(obj, TypeProto):
        raise TypeError(f"Unexpected type {type(obj)!r}.")
    if not obj.tensor_type.HasField("shape"):
        return None
    # Known dimensions become ints, named dimensions keep their parameter.
    dims = [
        d.dim_value if d.dim_value > 0 else d.dim_param
        for d in obj.tensor_type.shape.dim
    ]
    if not dims:
        return dims
    # Unknown dimensions (0 or empty string) are mapped to None.
    return [None if d in (0, "") else d for d in dims]
def enumerate_model_node_outputs(
    model: ModelProto, add_node: bool = False, order: bool = False
) -> Iterable:
    """
    Enumerates all the nodes of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuple (output name, node)
    :param order: goes through outputs following the graph order
    :return: enumerator
    """
    if not hasattr(model, "graph"):
        raise TypeError(f"Parameter model is not an ONNX model but {type(model)}")
    if order:
        edges = []
        # (kind, name) -> depth, kind is 0 for result names, 1 for node names
        dorder = {}
        # output name -> producing node
        node_names = {}
        for inp in model.graph.input:
            dorder[0, inp.name] = 0
        for node in model.graph.node:
            dorder[1, node.name] = 0
            for i in node.input:
                edges.append(("in", i, node.name))
            for o in node.output:
                edges.append(("out", o, node.name))
                node_names[o] = node
                dorder[0, o] = 0
        # Fixed-point loop propagating depths along the edges until stable,
        # bounded by the number of nodes so it terminates on any graph.
        modif = 1
        n_iter = 0
        while modif > 0 and n_iter <= len(model.graph.node):
            modif = 0
            n_iter += 1
            for kind, data_name, node_name in edges:
                if kind == "in":
                    # a node is deeper than every one of its inputs
                    if (0, data_name) not in dorder:
                        continue
                    if dorder[0, data_name] + 1 > dorder[1, node_name]:
                        modif += 1
                        dorder[1, node_name] = dorder[0, data_name] + 1
                else:
                    # an output is deeper than the node producing it
                    if dorder[1, node_name] + 1 > dorder[0, data_name]:
                        modif += 1
                        dorder[0, data_name] = dorder[1, node_name] + 1
        # Yields result names sorted by increasing depth.
        orders = [(v, k) for k, v in dorder.items()]
        orders.sort()
        for _, k in orders:
            if k[0] == 1:
                # a node name, not a result name
                continue
            out = k[1]
            if out not in node_names:
                # a graph input, not produced by any node
                continue
            yield (out, node_names[out]) if add_node else out
    else:
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out
def _info_type(typ: Union[TensorProto, TypeProto, SparseTensorProto]) -> Dict[str, str]:
    """Returns a small dictionary describing the type of a tensor or result."""
    if typ is None:
        return {}
    if isinstance(typ, (TensorProto, SparseTensorProto)):
        # Concrete tensor: dimensions are all known integers.
        dims = "x".join(str(d) for d in typ.dims)
        return dict(type="tensor", elem_type=_rev_type[typ.data_type], shape=dims)
    if typ.tensor_type:
        # TypeProto: dimensions may be symbolic or unknown ("?").
        dims = []
        for d in typ.tensor_type.shape.dim:
            dims.append(str(d.dim_value) if d.dim_value else (d.dim_param or "?"))
        info = dict(type="tensor", elem_type=_rev_type[typ.tensor_type.elem_type])
        info["shape"] = "x".join(dims)
        return info
    return dict(kind=str(type(typ)))
def enumerate_onnx_node_types(
    model: Union[str, ModelProto, GraphProto],
    level: int = 0,
    shapes: Optional[Dict[str, TypeProto]] = None,
    external: bool = True,
) -> Generator[Dict[str, Union[str, float]], None, None]:
    """
    Looks into types for every node in a model.

    :param model: a string or a proto
    :param level: level (recursivity level)
    :param shapes: known shapes,
        returned by :func:`onnx.shape_inference.infer_shapes`
    :param external: loads the external data if the model is loaded
    :return: a list of dictionary which can be turned into a dataframe.
    """
    # fixed: the def line carried a "[docs]" Sphinx-HTML artifact which
    # made the file invalid Python
    proto = load_model(model, external=external)
    if shapes is None and isinstance(proto, ModelProto):
        # Runs shape inference once and indexes every known type by name.
        p2 = onnx_infer_shapes(proto)
        values = p2.graph.value_info
        shapes = {}
        for value in values:
            shapes[value.name] = value.type
        for o in proto.graph.output:
            if o.name not in shapes:
                shapes[o.name] = o.type

    if isinstance(proto, ModelProto):
        if shapes is None:
            raise RuntimeError("shape inference has failed.")
        for item in enumerate_onnx_node_types(proto.graph, level=level, shapes=shapes):
            yield item
    elif isinstance(proto, FunctionProto):
        # fixed: previously tested `model` instead of the loaded `proto`,
        # which the error message already referred to
        raise NotImplementedError(f"Not implemented for type {type(proto)}.")
    else:
        # GraphProto: enumerates inputs, initializers, nodes, outputs.
        for inp in proto.input:
            obs = dict(level=level, name=inp.name, kind="input")
            obs.update(_info_type(inp.type))
            yield obs

        for init in proto.initializer:
            obs = dict(level=level, name=init.name, kind="initializer")
            obs.update(_info_type(init))
            yield obs

        for init in proto.sparse_initializer:
            obs = dict(level=level, name=init.name, kind="sparse_initializer")
            obs.update(_info_type(init))
            yield obs

        for node in proto.node:
            obs = dict(
                level=level,
                name=node.name,
                kind="Op",
                domain=node.domain,
                type=node.op_type,
                inputs=",".join(node.input),
                outputs=",".join(node.output),
                input_types=",".join(
                    _info_type(shapes.get(i, None)).get("elem_type", "")
                    for i in node.input
                ),
                output_types=",".join(
                    _info_type(shapes.get(i, None)).get("elem_type", "")
                    for i in node.output
                ),
            )
            yield obs

            # Recurses into subgraph attributes (If, Loop, Scan...).
            for att in node.attribute:
                if att.type == AttributeProto.GRAPH:
                    obs = dict(name=att.name, kind="attribute", level=level + 1)
                    yield obs
                    for item in enumerate_onnx_node_types(
                        att.g, level=level + 1, shapes=shapes
                    ):
                        yield item

            for out in node.output:
                obs = dict(name=out, kind="result", level=level)
                obs.update(_info_type(shapes.get(out, None)))
                yield obs

        for out in proto.output:
            obs = dict(level=level, name=out.name, kind="output")
            obs.update(_info_type(out.type))
            yield obs
def enumerate_model_tensors(
    model: ModelProto,
) -> Iterable[Tuple[TensorProto, bool]]:
    """
    Enumerates all tensors in a model.

    :param model: model to process
    :return: iterator on couples (TensorProto, bool),
        the boolean indicates if the data is external
    """
    # fixed: the def line carried a "[docs]" Sphinx-HTML artifact which
    # made the file invalid Python.
    # Imported locally to keep the dependency on private onnx helpers
    # confined to this function.
    from onnx.external_data_helper import (
        _get_all_tensors,
        uses_external_data,
    )

    for tensor in _get_all_tensors(model):
        yield tensor, uses_external_data(tensor)
def onnx_merge_models(
    m1: ModelProto, m2: ModelProto, io_map: List[Tuple[str, str]], verbose: int = 0
) -> ModelProto:
    """
    Merges two models. The function also checks that both models
    declare the same opsets (except for functions).
    If not, the most recent opset is selected.

    :param m1: first model
    :param m2: second model
    :param io_map: mapping between outputs of the first model
        and inputs of the second one
    :param verbose: display some information if one of the models was updated
    :return: new model
    """
    # fixed: the def line carried a "[docs]" Sphinx-HTML artifact which
    # made the file invalid Python
    opsets1 = {o.domain: o.version for o in m1.opset_import}
    opsets2 = {o.domain: o.version for o in m2.opset_import}
    # Collects domains declared by both models with different versions.
    update = {}
    for k, v2 in opsets2.items():
        if k not in opsets1:
            continue
        v1 = opsets1[k]
        if v1 == v2:
            continue
        update[k] = max(v1, v2)
    if len(update) > 0 and verbose > 0:
        print(f"[onnx_merge_models] selected opsets: {update}")
    if "" in update:
        # Only the default domain can be upgraded by convert_version.
        if opsets1[""] != update[""]:
            if verbose:
                print(
                    f"[onnx_merge_models] update model 1 from "
                    f"{opsets1['']} to {update['']}"
                )
            m1 = convert_version(m1, update[""])
        if opsets2[""] != update[""]:
            if verbose:
                print(
                    f"[onnx_merge_models] update model 2 from "
                    f"{opsets2['']} to {update['']}"
                )
            m2 = convert_version(m2, update[""])
    if m1.ir_version != m2.ir_version:
        # Aligns IR versions on the most recent one.
        new_ir = max(m1.ir_version, m2.ir_version)
        m1.ir_version = new_ir
        m2.ir_version = new_ir
    if verbose:
        for k, v in update.items():
            if k == "":
                continue
            print(f"[onnx_merge_models] no update implemented for domain {k!r} to {v}")
    return merge_models(m1, m2, io_map=io_map)
def tree_ensemble_use_as_tensor_attributes(node: NodeProto) -> NodeProto:
    """
    Uses only attributes suffixed with `_as_tensor` for tree ensemble operators.

    :param node: node to update
    :return: modified node
    """
    if node.op_type not in {"TreeEnsembleRegressor", "TreeEnsembleClassifier"}:
        raise NotImplementedError(
            f"Unable to apply tree_ensemble_use_as_tensor_attributes "
            f"on operator type {node.op_type}."
        )
    convertible = {
        "base_values",
        "nodes_hitrates",
        "nodes_values",
        "target_weights",
        "class_weights",
    }
    new_attributes = []
    changed = False
    for att in node.attribute:
        if att.name in convertible:
            # Replaces the float list with an equivalent tensor attribute.
            values = list(att.floats)
            tensor = make_tensor(att.name, TensorProto.FLOAT, [len(values)], values)
            new_attributes.append(make_attribute(att.name + "_as_tensor", tensor))
            changed = True
        else:
            new_attributes.append(att)
    if not changed:
        return node
    new_node = make_node(
        node.op_type, node.input, node.output, name=node.name, domain=node.domain
    )
    new_node.attribute.extend(new_attributes)
    return new_node
def convert_onnx_model(
    onnx_model: Union[ModelProto, GraphProto, NodeProto, FunctionProto],
    opsets: Dict[str, int],
    recursive: bool = True,
    use_as_tensor_attributes: bool = True,
    verbose: int = 0,
    _from_opset: Optional[Dict[str, int]] = None,
    debug_info: Optional[List[str]] = None,
) -> Union[ModelProto, GraphProto, NodeProto, FunctionProto]:
    """
    Upgrades a model to the latest opsets.

    :param onnx_model: proto
    :param opsets: list of opsets to update
    :param recursive: looks into subgraphs
    :param use_as_tensor_attributes: use attributes suffixed with `as_tensor` for trees
    :param verbose: verbosity
    :param _from_opset: tells which opset a node belongs to, only used when
        onnx_model is a NodeProto
    :param debug_info: unused
    :return: new proto
    """
    # fixed: the def line carried a "[docs]" Sphinx-HTML artifact which
    # made the file invalid Python

    def _change_opsets(proto):
        # Overwrites the opset imports of *proto* with *opsets*,
        # keeping the domains it already declares.
        old_opsets = {d.domain: d.version for d in proto.opset_import}
        old_opsets.update(opsets)
        del proto.opset_import[:]
        for k, v in old_opsets.items():
            d = proto.opset_import.add()
            d.domain = k
            d.version = v

    if isinstance(onnx_model, ModelProto):
        if _from_opset is not None:
            raise RuntimeError("_from_opset must be None for ModelProto.")
        old_opsets = {d.domain: d.version for d in onnx_model.opset_import}
        if verbose > 0:
            print(
                f"[convert_onnx_model] upgrade ModelProto from opset "
                f"{old_opsets} to {opsets}"
            )
        new_opset = opsets.get("", old_opsets.get("", None))
        if new_opset != old_opsets.get("", None):
            # convert_version only handles the default domain.
            onnx_model = convert_version(onnx_model, opsets)
        new_onnx = _apply_optimisation_on_graph(
            convert_onnx_model,
            onnx_model,
            recursive=recursive,
            verbose=verbose,
            opsets=opsets,
            use_as_tensor_attributes=use_as_tensor_attributes,
            _from_opset=old_opsets,
        )
        _change_opsets(new_onnx)
        return new_onnx

    is_function = isinstance(onnx_model, FunctionProto)
    if not is_function and not isinstance(_from_opset, dict):
        raise TypeError(f"_from_opset must be a dictionary not {_from_opset!r}.")

    if isinstance(onnx_model, NodeProto):
        node = onnx_model
        if verbose > 1:
            print(
                f"[convert_onnx_model] upgrade node {node.op_type!r} from opset "
                f"{_from_opset.get(node.domain, 0)} to {opsets.get(node.domain, 0)}"
            )
        if (
            _from_opset.get(node.domain, 0) == 0
            or node.domain != "ai.onnx.ml"
            or opsets.get(node.domain) is None
        ):
            # Only nodes of the ai.onnx.ml domain need an upgrade here.
            return node
        if _from_opset[node.domain] != opsets[node.domain]:
            if node.op_type in {"TreeEnsembleRegressor", "TreeEnsembleClassifier"}:
                # nothing to do
                pass
            else:
                raise NotImplementedError(
                    f"No upgrade is available from {_from_opset} to "
                    f"{opsets[node.domain]} for operator type "
                    f"{node.domain}.{node.op_type}."
                )
        if node.op_type in {"TreeEnsembleRegressor", "TreeEnsembleClassifier"}:
            if use_as_tensor_attributes:
                if _from_opset[node.domain] < 3 and opsets.get(node.domain, 0) < 3:
                    raise RuntimeError(
                        "Opset 3 is required when use_as_tensor_attributes is True."
                    )
                new_node = tree_ensemble_use_as_tensor_attributes(node)
                if verbose > 2:
                    atts = ", ".join(a.name for a in new_node.attribute)
                    print(f"[convert_onnx_model] {node.op_type}: attributes={atts!r}")
                return new_node
        return node

    if is_function:
        old_opsets = {d.domain: d.version for d in onnx_model.opset_import}
    else:
        old_opsets = _from_opset
    if verbose > 1:
        print(
            f"[convert_onnx_model] upgrade {type(onnx_model)} from opset "
            f"{old_opsets} to {opsets}"
        )

    new_nodes = [
        convert_onnx_model(
            node,
            recursive=recursive,
            verbose=verbose,
            opsets=opsets,
            use_as_tensor_attributes=use_as_tensor_attributes,
            _from_opset=_from_opset,
        )
        for node in onnx_model.node
    ]

    # Finally create the new graph or function. Drops nodes removed by the
    # conversion; fixed: the filtered list was previously built but the
    # unfiltered `new_nodes` was used instead.
    nodes = list(filter(lambda n: n is not None, new_nodes))
    if is_function:
        onx = make_function(
            onnx_model.domain,
            onnx_model.name,
            onnx_model.input,
            onnx_model.output,
            nodes,
            opset_imports=onnx_model.opset_import,  # overwritten just below
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string,
        )
        _change_opsets(onx)
        return onx
    graph = make_graph(
        nodes,
        onnx_model.name,
        onnx_model.input,
        onnx_model.output,
        onnx_model.initializer,
        sparse_initializer=onnx_model.sparse_initializer,
    )
    graph.value_info.extend(onnx_model.value_info)
    return graph
def multiply_tree(node: NodeProto, n: int) -> NodeProto:
    """
    Multiplies the number of trees in TreeEnsemble operator.
    It replicates the existing trees.

    :param node: tree ensemble operator
    :param n: number of times the existing trees must be multiplied
    :return: the new trees
    :raises TypeError: if *node* is not a NodeProto
    :raises ValueError: if *node* is not a TreeEnsemble operator
    """
    # fixed: the def line carried a "[docs]" Sphinx-HTML artifact which
    # made the file invalid Python
    if not isinstance(node, NodeProto):
        raise TypeError(f"node is not a NodeProto but {type(node)}.")
    if not node.op_type.startswith("TreeEnsemble"):
        raise ValueError(f"Unexpected node type {node.op_type!r}.")
    args = [node.op_type, node.input, node.output]
    kwargs = {"domain": node.domain}
    for att in node.attribute:
        if att.name in {"aggregate_function", "post_transform"}:
            kwargs[att.name] = att.s
        elif att.name in {"n_targets"}:
            kwargs[att.name] = att.i
        elif att.name == "classlabels_int64s":
            ints = att.ints
            if len(ints) > 0:
                kwargs[att.name] = list(ints)
        elif att.name == "classlabels_strings":
            vals = att.strings
            if len(vals) > 0:
                kwargs[att.name] = list(vals)
        elif att.name == "base_values":
            fs = list(att.floats)
            if len(fs) > 0:
                # fixed: was `kwargs.att[att.name] = fs`, which raised
                # AttributeError whenever base_values was set
                kwargs[att.name] = fs
        elif att.name.endswith("_as_tensor"):
            v = to_array(att.t)
            if att.name == "base_values_as_tensor":
                if len(v.shape) > 0:
                    kwargs[att.name] = from_array(v)
                else:
                    kwargs[att.name] = from_array(np.repeat(v, n))
        elif att.name in {
            "class_ids",
            "class_nodeids",
            "nodes_falsenodeids",
            "nodes_featureids",
            "nodes_missing_value_tracks_true",
            "nodes_nodeids",
            "nodes_truenodeids",
            "target_ids",
            "target_nodeids",
        }:
            # per-node attributes are simply replicated
            kwargs[att.name] = list(att.ints) * n
        elif att.name in {"class_treeids", "nodes_treeids", "target_treeids"}:
            arr = np.array(att.ints, dtype=np.int64)
            # Every replica must get fresh tree ids: shift by the number of
            # trees (max id + 1). fixed: the previous shift of 1 per replica
            # made replicated ids collide when the ensemble held more than
            # one tree.
            shift = int(arr.max()) + 1 if arr.size > 0 else 1
            ints = []
            for _ in range(n):
                ints.extend(arr.tolist())
                arr += shift
            kwargs[att.name] = ints
        elif att.name in {"nodes_modes"}:
            kwargs[att.name] = list(att.strings) * n
        elif att.name in {
            "nodes_hitrates",
            "nodes_values",
            "target_weights",
            "class_weights",
        }:
            fs = list(att.floats)
            if len(fs) > 0:
                kwargs[att.name] = fs * n
    return make_node(*args, **kwargs)