Source code for onnx_extended.tools.onnx_manipulations
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
from onnx import (
AttributeProto,
FunctionProto,
GraphProto,
ModelProto,
NodeProto,
ValueInfoProto,
TensorProto,
TypeProto,
shape_inference,
)
from onnx.helper import (
make_attribute,
make_graph,
make_function,
make_model,
make_tensor_value_info,
np_dtype_to_tensor_dtype,
set_model_props,
)
# Logger used by the functions of this module; note it is named after the
# package ("onnx-extended"), not after this module.
logger: logging.Logger = logging.getLogger(name="onnx-extended")
def _make_att_graph(name: str, new_body: GraphProto) -> AttributeProto:
    """
    Wraps a graph into an attribute of type GRAPH.

    :param name: attribute name (``body``, ``then_branch``, ...)
    :param new_body: graph stored in the attribute, it is copied
        so *new_body* itself is not modified
    :return: the new AttributeProto
    """
    graph_attribute = AttributeProto(name=name, type=AttributeProto.GRAPH)
    graph_attribute.g.CopyFrom(new_body)
    return graph_attribute
def _make_node(
    op_type: str,
    inputs: Iterable[str],
    outputs: Iterable[str],
    name: Optional[str] = None,
    doc_string: Optional[str] = None,
    domain: Optional[str] = "",
    attributes: Optional[Union[Dict[str, Any], Iterable[AttributeProto]]] = None,
) -> NodeProto:
    """
    Constructs a NodeProto.

    :param op_type: the name of the operator to construct
    :param inputs: input names
    :param outputs: output names
    :param name: optional unique identifier for NodeProto,
        left unset when empty
    :param doc_string: optional documentation string for NodeProto,
        left unset when empty
    :param domain: domain of the node; None leaves the field unset,
        which is the default (empty) domain as well
    :param attributes: either a dictionary ``{name: value}`` whose items
        are converted with `make_attribute` (in sorted key order),
        or a sequence of already built AttributeProto copied as they are
    :return: node
    """
    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if isinstance(attributes, dict):
        if attributes:
            node.attribute.extend(
                make_attribute(key, value) for key, value in sorted(attributes.items())
            )
    elif attributes:
        # Already built AttributeProto (e.g. the attributes of another node).
        node.attribute.extend(attributes)
    return node
def _apply_optimisation_on_graph(
    fct: Callable,
    onnx_model: Union[ModelProto, GraphProto, FunctionProto],
    recursive: bool = True,
    debug_info: Optional[List[str]] = None,
    **kwargs: Any,
) -> Union[ModelProto, GraphProto, FunctionProto]:
    """
    Applies an optimisation function *fct* on the main graph of a model
    and wraps the optimized graph into a new model.

    :param fct: function to optimize like :func:`onnx_remove_node_unused`,
        called as ``fct(graph, recursive=..., debug_info=..., **kwargs)``
    :param onnx_model: onnx model (anything with a ``graph`` attribute)
    :param recursive: looks into subgraphs, forwarded to *fct*
    :param debug_info: debug information (private)
    :param kwargs: additional parameters forwarded to *fct*
    :return: new onnx model keeping the functions, ir_version, producer,
        domain, model_version, doc_string, metadata and opsets of *onnx_model*
    :raises TypeError: if *onnx_model* has no ``graph`` attribute
    """
    if not hasattr(onnx_model, "graph"):
        raise TypeError(
            f"This function only works on 'ModelProto' and not on {type(onnx_model)}."
        )
    if debug_info is None:
        debug_info = []
    # recursive must be forwarded, otherwise fct silently uses its own default.
    graph = fct(
        onnx_model.graph,
        recursive=recursive,
        debug_info=debug_info + ["GRAPH"],
        **kwargs,
    )
    new_model = make_model(graph, functions=onnx_model.functions)
    new_model.ir_version = onnx_model.ir_version
    new_model.producer_name = onnx_model.producer_name
    new_model.producer_version = onnx_model.producer_version
    new_model.domain = onnx_model.domain
    new_model.model_version = onnx_model.model_version
    new_model.doc_string = onnx_model.doc_string
    if len(onnx_model.metadata_props) > 0:
        set_model_props(
            new_model, {p.key: p.value for p in onnx_model.metadata_props}
        )
    # make_model adds a default opset, replace it with the original ones.
    del new_model.opset_import[:]
    for oimp in onnx_model.opset_import:
        op_set = new_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version
    return new_model
def _apply_remove_node_fct_node(
    fct: Callable, node: NodeProto, recursive: bool, debug_info: List[str]
) -> NodeProto:
    """
    Applies an optimizing function on the subgraphs held by a node.

    Only the attributes named ``body``, ``then_branch`` and ``else_branch``
    are considered.

    :param fct: optimizing function, called as
        ``fct(subgraph, recursive=recursive, debug_info=...)``
    :param node: onnx node
    :param recursive: does it in subgraphs as well, forwarded to *fct*
    :param debug_info: debug information (private), extended with
        the attribute name before calling *fct*
    :return: *node* itself if it holds none of these attributes,
        otherwise a new node with the optimized subgraphs
    """
    if not hasattr(node, "attribute"):
        return node
    modified = False
    new_atts = []
    for att in node.attribute:
        if att.name in ("body", "then_branch", "else_branch"):
            new_body = fct(
                att.g, recursive=recursive, debug_info=debug_info + [att.name]
            )
            new_atts.append(_make_att_graph(att.name, new_body))
            modified = True
        else:
            new_atts.append(att)
    if not modified:
        return node
    # domain and doc_string must be kept: dropping the domain would turn
    # a node from a custom domain into a node of the default domain.
    return _make_node(
        node.op_type,
        node.input,
        node.output,
        name=node.name,
        doc_string=node.doc_string,
        domain=node.domain,
        attributes=new_atts,
    )
def _process_node(
node: NodeProto,
data: Dict,
edges: Dict,
paths: Dict,
prefix: str = "",
sep: str = ":X:",
path: Optional[List[str]] = None,
):
node_name = prefix + node.name
data[node_name, 1] = node
path = [] if path is None else path.copy()
paths[node_name, 1] = path
path = path.copy()
path.append(node_name)
for inp in node.input:
data[inp, 0] = node
edges[(inp, 0), (node_name, 1)] = node
paths[inp, 0] = path
if sep in node_name:
# We need to link an input to the parent node
# if the node is part of subgraph.
# path_r = paths[inp, 0]
if len(path) <= 1:
raise RuntimeError(
f"Unexpected path {path!r}, this may happen "
f"if sep={sep!r} is already used in the original model."
)
edges[(inp, 0), (path[-2], 1)] = node
for out in node.output:
data[out, 0] = node
paths[out, 0] = node_name
edges[(node_name, 1), (out, 0)] = node
if len(node.attribute) > 0:
for att in node.attribute:
if not hasattr(att, "g"):
continue
if not isinstance(att.g, GraphProto):
continue
for no in att.g.node:
_process_node(no, data, edges, paths, prefix=node_name + sep, path=path)
def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    Initializers and sparse initializers no longer read by any kept node
    are removed as well. Nodes are identified by their name: nodes sharing
    a name (several unnamed nodes for example) are kept or removed together.

    :param onnx_model: onnx model (ModelProto), graph (GraphProto)
        or function (FunctionProto)
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private)
    :param options: unused, forwarded when *onnx_model* is a model
    :return: new onnx model, graph or function, same kind as *onnx_model*
    """
    type_name = str(type(onnx_model)).rsplit(".", maxsplit=1)[-1].strip("'>")
    debug_info = [type_name] if debug_info is None else debug_info + [type_name]
    if hasattr(onnx_model, "graph"):
        return _apply_optimisation_on_graph(
            onnx_remove_node_unused,
            onnx_model,
            recursive=recursive,
            debug_info=debug_info,
            **options,
        )

    graph = onnx_model
    logger.debug("onnx_remove_node_unused:begin with %d nodes.", len(graph.node))
    is_function = isinstance(graph, FunctionProto)

    # Builds the dependency graph between nodes and results.
    data = {}
    edges = {}
    paths = {}
    if not is_function:
        for init in graph.initializer:
            data[init.name, 0] = init
    for node in graph.node:
        _process_node(node, data, edges, paths)

    # Backward propagation from the outputs: everything an output
    # depends on becomes valid, until a fixed point is reached.
    valid = {(out if is_function else out.name, 0) for out in graph.output}
    changed = True
    while changed:
        changed = False
        for e1, e2 in edges:  # pylint: disable=E1141
            if e2 in valid and e1 not in valid:
                valid.add(e1)
                changed = True

    nodes = [n for n in graph.node if (n.name, 1) in valid]
    if recursive:
        # Cleans the subgraphs (body, then_branch, else_branch) as well.
        nodes = [
            _apply_remove_node_fct_node(
                onnx_remove_node_unused,
                node,
                recursive=True,
                debug_info=debug_info + [node.name],
            )
            if node.attribute
            else node
            for node in nodes
        ]

    if is_function:
        logger.debug("onnx_remove_node_unused:end function with %d nodes.", len(nodes))
        # NOTE(review): FunctionProto.attribute_proto (attributes with
        # default values in recent onnx versions) is not forwarded here,
        # confirm whether it should be.
        return make_function(
            onnx_model.domain,
            onnx_model.name,
            onnx_model.input,
            onnx_model.output,
            nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string,
        )

    new_inits = [i for i in graph.initializer if (i.name, 0) in valid]
    # A sparse initializer is named after its values tensor.
    new_sparse = [s for s in graph.sparse_initializer if (s.values.name, 0) in valid]
    new_graph = make_graph(
        nodes,
        onnx_model.name,
        onnx_model.input,
        onnx_model.output,
        new_inits,
        doc_string=onnx_model.doc_string,
        sparse_initializer=new_sparse,
    )
    new_graph.value_info.extend(onnx_model.value_info)
    logger.debug("onnx_remove_node_unused:end graph with %d nodes.", len(nodes))
    return new_graph
def _guess_proto_dtype(dtype) -> int:
    """
    Converts a numpy dtype into the corresponding ONNX element type
    (``TensorProto.FLOAT``, ...), delegating to `np_dtype_to_tensor_dtype`.
    """
    element_type = np_dtype_to_tensor_dtype(dtype)
    return element_type
def get_tensor_shape(
    obj: Union[ValueInfoProto, TypeProto, TensorProto]
) -> Optional[List[Union[int, str, None]]]:
    """
    Returns the shape if that makes sense for this object.

    :param obj: a ValueInfoProto (its type is inspected), a TypeProto,
        or a TensorProto (its ``dims`` are returned)
    :return: None if the type holds no tensor shape, otherwise one entry
        per dimension: an int for a strictly positive ``dim_value``,
        the ``dim_param`` string for a named dimension, None otherwise
    :raises TypeError: for any other type
    """
    if isinstance(obj, ValueInfoProto):
        return get_tensor_shape(obj.type)
    if isinstance(obj, TensorProto):
        # A TensorProto stores its concrete dimensions in dims.
        return list(obj.dims)
    if not isinstance(obj, TypeProto):
        raise TypeError(f"Unexpected type {type(obj)!r}.")
    if not obj.tensor_type.HasField("shape"):
        return None
    shape = []
    for d in obj.tensor_type.shape.dim:
        # An unset dim_value reads as 0, so 0 is treated as unknown.
        if d.dim_value > 0:
            shape.append(d.dim_value)
        else:
            shape.append(d.dim_param or None)
    return shape
def enumerate_model_node_outputs(
    model: ModelProto, add_node: bool = False, order: bool = False
) -> Iterable:
    """
    Enumerates all the outputs of the nodes of a model.

    :param model: :epkg:`ONNX` graph
    :param add_node: if False, the function enumerates
        all output names from every node, otherwise, it
        enumerates tuple (output name, node)
    :param order: goes through outputs following the graph order,
        sorted by their depth (longest path from the graph inputs),
        ties are broken by output name
    :return: enumerator
    :raises TypeError: if *model* has no ``graph`` attribute
    """
    if not hasattr(model, "graph"):
        raise TypeError(f"Parameter model is not an ONNX model but {type(model)}")
    if not order:
        for node in model.graph.node:
            for out in node.output:
                yield (out, node) if add_node else out
        return

    # Nodes are keyed by their position, not their name: names may be
    # empty or duplicated, and a shared key would merge several nodes
    # into one, creating cycles which prevent the levels from converging.
    edges = []
    dorder = {}
    producers = {}
    for inp in model.graph.input:
        dorder[0, inp.name] = 0
    for index, node in enumerate(model.graph.node):
        dorder[1, index] = 0
        for i in node.input:
            edges.append(("in", i, index))
        for o in node.output:
            edges.append(("out", o, index))
            producers[o] = node
            dorder[0, o] = 0

    # Longest-path levels, bounded by the number of nodes.
    modif = 1
    n_iter = 0
    while modif > 0 and n_iter <= len(model.graph.node):
        modif = 0
        n_iter += 1
        for kind, data_name, node_index in edges:
            if kind == "in":
                if (0, data_name) not in dorder:
                    # initializer or unknown name
                    continue
                if dorder[0, data_name] + 1 > dorder[1, node_index]:
                    modif += 1
                    dorder[1, node_index] = dorder[0, data_name] + 1
            elif dorder[1, node_index] + 1 > dorder[0, data_name]:
                modif += 1
                dorder[0, data_name] = dorder[1, node_index] + 1

    # Keys (0, name) and (1, index) never compare a str with an int
    # because their first element differs.
    for _, key in sorted((v, k) for k, v in dorder.items()):
        if key[0] == 1:
            continue
        out = key[1]
        if out not in producers:
            continue
        yield (out, producers[out]) if add_node else out
def _new_io_value_info(
    name: str,
    overwrite: Optional[Dict[str, Any]],
    known_shapes: Dict[str, TypeProto],
) -> ValueInfoProto:
    """
    Builds the ValueInfoProto of a new input or output
    for :func:`select_model_inputs_outputs`.

    An entry of *overwrite* ``(numpy dtype, shape)`` wins, then the type
    found in *known_shapes*; if neither gives an element type, the
    ValueInfoProto only holds the name.
    """
    if overwrite is not None and name in overwrite:
        dtype, shape = overwrite[name]
        return make_tensor_value_info(name, _guess_proto_dtype(dtype), shape)
    if name in known_shapes:
        proto_dtype = known_shapes[name].tensor_type.elem_type
        if proto_dtype != 0:
            shape = get_tensor_shape(known_shapes[name])
            return make_tensor_value_info(name, proto_dtype, shape)
    # Element type unknown (or undefined): only the name can be set.
    value_info = ValueInfoProto()
    value_info.name = name
    return value_info


def select_model_inputs_outputs(
    model: ModelProto,
    outputs: Optional[List[str]] = None,
    inputs: Optional[List[str]] = None,
    infer_shapes: bool = True,
    overwrite: Optional[Dict[str, Any]] = None,
    remove_unused: bool = True,
    verbose: int = 0,
):
    """
    Takes a model and changes its outputs.

    :param model: :epkg:`ONNX` model
    :param inputs: new inputs (a name or a sequence of names),
        same ones if None
    :param outputs: new outputs (a name or a sequence of names),
        same ones if None
    :param infer_shapes: infer inputs and outputs shapes
    :param overwrite: overwrite type and shapes for
        inputs or outputs, *overwrite* is a
        dictionary `{'name': (numpy dtype, shape)}`
    :param remove_unused: remove unused nodes from the graph
        (subgraphs are not cleaned)
    :param verbose: display information while converting
    :return: modified model
    :raises ValueError: if a requested output is not produced
        by any node and is not an input

    The function removes unneeded nodes.
    The following example shows how to change the inputs of model
    to bypass the first nodes. Shape inference fails to determine
    the new inputs type. They need to be overwritten.
    `verbose=1` shows the number of deleted nodes.

    ::

        import numpy
        import onnx
        from onnx_extended.tools.onnx_manipulations import select_model_inputs_outputs

        onx = onnx.load(path)
        onx2 = select_model_inputs_outputs(
            onx, inputs=["a", "b"],
            infer_shapes=True, verbose=1,
            overwrite={'a': (numpy.int32, None), 'b': (numpy.int64, None)})
        onnx.save(onx2, path2)
    """
    # A single name is wrapped, any other sequence is converted into a list.
    if isinstance(inputs, str):
        inputs = [inputs]
    elif inputs is not None:
        inputs = list(inputs)
    if isinstance(outputs, str):
        outputs = [outputs]
    elif outputs is not None:
        outputs = list(outputs)
    if inputs is None:
        inputs = [i.name for i in model.graph.input]
    if outputs is None:
        outputs = [o.name for o in model.graph.output]
    input_set = set(inputs)

    mark_var = {}
    for out in enumerate_model_node_outputs(model):
        mark_var[out] = 0
    for inp in inputs:
        mark_var[inp] = 0
    for out in outputs:
        if out not in mark_var:
            raise ValueError(f"Output '{out}' not found in model.")
        mark_var[out] = 1

    nodes = list(model.graph.node[::-1])
    mark_op = {id(node): 0 for node in nodes}

    # We mark all the nodes we need to keep: a node is needed if one of its
    # outputs is needed, its inputs (explicit or hidden ones read by its
    # subgraphs) become needed unless they are new inputs.
    nb = 1
    while nb > 0:
        nb = 0
        for node in nodes:
            if mark_op[id(node)] == 1:
                continue
            if not any(mark_var[out] == 1 for out in node.output):
                continue
            mark_op[id(node)] = 1
            nb += 1
            node_inputs = list(node.input) + list(get_hidden_inputs([node]))
            for inp in node_inputs:
                if inp in input_set or mark_var.get(inp, 0) == 1:
                    continue
                mark_var[inp] = 1
                nb += 1

    # All kept nodes verify mark_op[id(node)] == 1, original order restored.
    keep_nodes = [node for node in reversed(nodes) if mark_op[id(node)] == 1]

    if verbose > 1:
        for node in nodes:
            s = "+" if mark_op[id(node)] == 1 else "-"
            logger.info(
                "[select_model_inputs_outputs] %s %s (%s) -> %s [%s]",
                s,
                node.op_type,
                ", ".join(node.input),
                ", ".join(node.output),
                node.name,
            )

    known_shapes = {}
    if infer_shapes:
        shapes = shape_inference.infer_shapes(model)
        for shape in shapes.graph.value_info:
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.input:
            known_shapes[shape.name] = shape.type
        for shape in shapes.graph.output:
            known_shapes[shape.name] = shape.type
    else:
        for shape in model.graph.input:
            known_shapes[shape.name] = shape.type
        for shape in model.graph.output:
            known_shapes[shape.name] = shape.type

    var_in = [_new_io_value_info(name, overwrite, known_shapes) for name in inputs]
    var_out = [_new_io_value_info(name, overwrite, known_shapes) for name in outputs]

    if verbose > 0:
        logger.info(
            "[select_model_inputs_outputs] nodes %r --> %r",
            len(model.graph.node),
            len(keep_nodes),
        )
        logger.info(
            "[select_model_inputs_outputs] inputs: %r", [_.name for _ in var_in]
        )
        logger.info(
            "[select_model_inputs_outputs] outputs: %r", [_.name for _ in var_out]
        )

    graph = make_graph(
        keep_nodes,
        model.graph.name,
        var_in,
        var_out,
        model.graph.initializer,
        sparse_initializer=model.graph.sparse_initializer,
    )
    onnx_model = make_model(graph, functions=model.functions)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:
        values = {p.key: p.value for p in model.metadata_props}
        set_model_props(onnx_model, values)

    # make_model adds a default opset, replace it with the original ones.
    del onnx_model.opset_import[:]
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # remove unused nodes
    if remove_unused:
        onnx_model = onnx_remove_node_unused(onnx_model, recursive=False)
    return onnx_model