# Source code for onnx_extended.tools.onnx_inline

import pprint
from collections import Counter
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
from onnx import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    ValueInfoProto,
)
from onnx.helper import (
    make_graph,
    make_function,
    make_model,
    make_node,
    make_operatorsetid,
    make_attribute,
    make_value_info,
)


def enumerate_onnx_names(
    onx: Union[FunctionProto, GraphProto, ModelProto]
) -> Iterator[str]:
    """
    Enumerates all existing names in one ONNX graph
    (ModelProto, FunctionProto, GraphProto).
    The function is recursive and also walks through the subgraphs
    held by attributes of type GRAPH.
    A name may be returned several times and the empty string
    (an omitted optional input or output) may be returned as well.

    :param onx: one onnx object
    :return: iterator on names
    """
    if hasattr(onx, "graph"):
        # ModelProto: every name lives in the main graph.
        onx = onx.graph
    if hasattr(onx, "initializer"):
        # GraphProto
        for i in onx.initializer:
            yield i.name
        for i in getattr(onx, "sparse_initializer", []):
            # SparseTensorProto has no name field,
            # the name is stored in its values.
            yield i.values.name
        for i in onx.input:
            yield i.name
        for i in onx.output:
            yield i.name
    else:
        # FunctionProto: inputs and outputs are plain strings.
        if hasattr(onx, "input"):
            yield from onx.input
        if hasattr(onx, "output"):
            yield from onx.output
    for node in onx.node:
        yield from node.input
        yield from node.output
        for att in node.attribute:
            if (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                yield from enumerate_onnx_names(att.g)


def enumerate_onnx_nodes(
    onx: Union[FunctionProto, GraphProto, ModelProto, List[NodeProto]]
) -> Iterator[NodeProto]:
    """
    Enumerates all nodes in one ONNX graph
    (ModelProto, FunctionProto, GraphProto) or in a list of nodes.
    The function is recursive: the nodes of a subgraph held by an
    attribute of type GRAPH are returned right after the node holding it.

    :param onx: one onnx object or a list of nodes
    :return: iterator on nodes (NodeProto)
    """
    if isinstance(onx, list):
        nodes = onx
    elif hasattr(onx, "graph"):
        # ModelProto
        nodes = onx.graph.node
    else:
        # GraphProto, FunctionProto
        nodes = onx.node
    for node in nodes:
        yield node
        for att in node.attribute:
            if (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                yield from enumerate_onnx_nodes(att.g)


def _get_new_name(
    prefix: str, name: Union[str, ValueInfoProto], existing_names: Set[str]
) -> str:
    """
    Builds the first name ``<prefix>_<name>_<i>`` (``i = 0, 1, ...``)
    which does not belong to *existing_names* and registers it there.

    :param prefix: prefix of the new name
    :param name: original name
    :param existing_names: names already taken, the new name is added to it
    :return: the new name
    """
    suffix = 0
    while True:
        candidate = "%s_%s_%d" % (prefix, name, suffix)
        if candidate not in existing_names:
            break
        suffix += 1
    existing_names.add(candidate)
    return candidate


def onnx_subgraphs_level(obj: Union[FunctionProto, GraphProto, ModelProto]) -> int:
    """
    Returns the depth of the graph: 1 for a graph without any subgraph,
    1 + the deepest subgraph level otherwise.

    :param obj: onnx object
    :return: integer
    """
    if isinstance(obj, ModelProto):
        # The depth of a model is the depth of its main graph.
        obj = obj.graph
    sub_levels = [
        onnx_subgraphs_level(att.g)
        for node in obj.node
        for att in node.attribute
        if att.type == AttributeProto.GRAPH
        and hasattr(att, "g")
        and att.g is not None
    ]
    return max(sub_levels, default=0) + 1


class _inline_mapping(dict):
    """
    Dictionary used while inlining. It refuses to overwrite an existing
    key, which makes naming conflicts visible, and traces every insertion
    when the verbosity is high enough.

    :param verbose: verbosity, insertions are printed when it is above 3
    :param level: sub graph level, only used to indent the traces
    """

    def __init__(self, verbose: int, level: int):
        dict.__init__(self)
        self._verbose = verbose
        self._level = level

    def __setitem__(self, key: str, value: Any):
        "Adds a value, a key cannot be added twice."
        if self._verbose > 3:
            indent = "  " * self._level
            print("[_inline_mapping-dict-addkv] %s + %r: %r" % (indent, key, value))
        assert key not in self, (
            "Key %r was already added (with value %r, new one is %r)."
            % (key, self[key], value)
        )
        dict.__setitem__(self, key, value)

    def update(self, d: Dict[str, Any]):
        "Updates many values, every insertion goes through __setitem__."
        for new_key, new_value in d.items():
            self[new_key] = new_value

    def copy(self) -> Dict[str, Any]:
        "Returns a new _inline_mapping with the same settings and content."
        duplicate = _inline_mapping(self._verbose, self._level)
        duplicate.update(self)
        return duplicate

    def remove(self, o: str):
        "Removes one element, the key must exist."
        assert o in self, f"Cannot remove a key {o!r}."
        del self[o]


def _onnx_inline_function_graph(
    graph: GraphProto,
    protos: Dict[Tuple[str, str], FunctionProto],
    existing_names: Set[str],
    mapping: _inline_mapping,
    verbose: int,
    rename: bool,
    level: int,
) -> Tuple[
    Union[FunctionProto, GraphProto, ModelProto], List[Union[ValueInfoProto, NodeProto]]
]:
    """
    Processes a subgraph held by an attribute of a node.

    Every name the subgraph takes from the outer scope is replaced by
    its value in *mapping*. If *rename* is True, every result produced
    by a node of the subgraph, except the subgraph outputs, receives a new
    name which does not belong to *existing_names*. If *rename* is False
    and no name had to be changed, the nodes calling a function
    in *protos* are inlined.
    Omitted optional inputs or outputs (empty names) are left untouched.

    :param graph: subgraph to process
    :param protos: known functions, ``{(domain, name): FunctionProto}``
    :param existing_names: names which cannot be used, every new name is
        added to this set
    :param mapping: maps every name of the outer scope to its new name,
        the function works on a copy
    :param verbose: verbosity level
    :param rename: renames every intermediate result
    :param level: level of subgraphs
    :return: the new graph (*graph* itself if nothing changed),
        the list of modified nodes (or modified outputs)
    """
    if len(graph.node) == 0:
        # Outputs have still to be renamed.
        graph0 = graph
        if verbose > 1:
            print(
                "[onnx_inline_function-graph] %s visit0 graph=%d rename=%r "
                "len(mapping)=%d begin"
                % ("  " * level, id(graph), rename, len(mapping))
            )
        if rename:
            modified_nodes = []
            mapping = mapping.copy()
            local_names = (
                [i.name for i in graph.input]
                + [i.name for i in graph.initializer]
                # SparseTensorProto has no name field, it is held by its values.
                + [i.values.name for i in graph.sparse_initializer]
            )
            for name in local_names:
                # A local name hides an outer name with the same spelling.
                if name in mapping:
                    mapping.remove(name)
                mapping[name] = name
            outputs = []
            for o in graph.output:
                no = make_value_info(mapping[o.name], o.type)
                if no.name != o.name:
                    modified_nodes.append(o)
                    outputs.append(no)
                else:
                    outputs.append(o)
            if len(modified_nodes) > 0:
                graph = make_graph(
                    [],
                    graph.name,
                    graph.input,
                    outputs,
                    graph.initializer,
                    doc_string=graph.doc_string,
                    sparse_initializer=list(graph.sparse_initializer),
                )
        else:
            modified_nodes = []

        if verbose > 1:
            print(
                "[onnx_inline_function-graph] %s visit graph=%d end "
                "changed=%r len(modified_nodes)=%d"
                % (
                    "  " * level,
                    id(graph0),
                    id(graph0) != id(graph),
                    len(modified_nodes),
                )
            )

        return graph, modified_nodes

    graph0 = graph
    mapping = mapping.copy()
    init = list(graph.initializer)
    init_sparse = list(graph.sparse_initializer)
    inputs = list(graph.input)
    modified_nodes = []
    outputs = list(graph.output)

    if verbose > 1:
        print(
            "[onnx_inline_function-graph] %s >visit graph=%d rename=%r "
            "len(mapping)=%d begin" % ("  " * level, id(graph), rename, len(mapping))
        )

    output_names = [o.name for o in outputs]
    local_names = (
        [i.name for i in init]
        # SparseTensorProto has no name field, it is held by its values.
        + [i.values.name for i in init_sparse]
        + [i.name for i in inputs]
    )
    for name in local_names:
        # A local name hides an outer name with the same spelling,
        # it also covers initializers listed as inputs (IR < 4).
        if name in mapping:
            mapping.remove(name)
        mapping[name] = name

    # first step, replace names
    nodes = []
    for node in list(graph.node):
        mod = 0
        inp = []
        for i in node.input:
            if i == "":
                # omitted optional input, it stays omitted
                inp.append(i)
                continue
            assert (
                i in mapping
            ), "Cannot find input %r in %s for node (level=%d)\n%r." % (
                i,
                pprint.pformat(mapping),
                level,
                node,
            )
            inp.append(mapping[i])
            if mapping[i] != i:
                mod += 1
        out = []
        for o in node.output:
            if o == "":
                # omitted optional output, it must not be renamed or registered,
                # otherwise a later omitted input would be wired to it
                out.append(o)
                continue
            new_o = o
            if rename:
                if o not in output_names:
                    new_o = _get_new_name("_inl", o, existing_names)
                if o in mapping:
                    # See below.
                    mapping.remove(o)
            elif o in mapping:
                # That means the main contains a result node but is overwritten by
                # the subgraph. The local variable cannot be reached anymore,
                # we remove it.
                mapping.remove(o)
                if o in node.input:
                    new_o = _get_new_name("_inl", o, existing_names)
                if verbose > 3:
                    print(
                        "[onnx_inline_function-renam] %s node %r(%r): %r -> %r "
                        "overwrite result (%r -> %r)."
                        % (
                            "  " * level,
                            node.op_type,
                            node.name,
                            node.input,
                            node.output,
                            o,
                            new_o,
                        )
                    )
            out.append(new_o)
            mapping[o] = new_o
            if o != new_o:
                mapping[new_o] = new_o
                mod += 1

        if verbose > 3:
            print(
                "[onnx_inline_function-renam] %s rep node %r(%r): %r -> %r"
                % ("  " * level, node.op_type, node.name, node.input, node.output)
            )
        new_node = make_node(
            node.op_type,
            inp,
            out,
            domain=node.domain,
            name=_get_new_name("_inln", node.name, existing_names),
        )
        for att in node.attribute:
            if (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                g, m = _onnx_inline_function_graph(
                    att.g,
                    protos,
                    existing_names=existing_names,
                    verbose=verbose,
                    mapping=mapping,
                    rename=rename,
                    level=level + 1,
                )
                if len(m) > 0:
                    att = make_attribute(att.name, g)
                    mod += len(m)
                else:
                    att = make_attribute(att.name, att.g)
            new_node.attribute.append(att)
        if mod > 0:
            if verbose > 2:
                print(
                    "[onnx_inline_function-renam] %s add node %r(%r): %r -> %r"
                    % (
                        "  " * level,
                        new_node.op_type,
                        new_node.name,
                        new_node.input,
                        new_node.output,
                    )
                )
            nodes.append(new_node)
            modified_nodes.append(node)
        else:
            nodes.append(node)

    if len(modified_nodes) > 0:
        if verbose > 1:
            print(
                "[onnx_inline_function-graph] %s -1 graph=%d "
                "len(modified_nodes)=%d"
                % ("  " * level, id(graph), len(modified_nodes))
            )

        graph = make_graph(
            nodes,
            graph.name,
            inputs,
            outputs,
            init,
            doc_string=graph.doc_string,
            sparse_initializer=list(graph.sparse_initializer),
        )
    elif not rename:
        # no modification, let's check the node hiding a functions
        new_nodes = []
        for node in nodes:
            nnodes, m = _onnx_inline_function_node(
                node, protos, existing_names, verbose, level=level
            )
            if len(m) > 0:
                if verbose > 0:
                    print(
                        "[onnx_inline_function-subgr] %s replaced node %r (%r) "
                        "with %d nodes (id=%r) -- %r -> %r"
                        % (
                            "  " * level,
                            node.name,
                            node.op_type,
                            len(nnodes),
                            id(node),
                            node.input,
                            node.output,
                        )
                    )
                new_nodes.extend(nnodes)
                modified_nodes.extend(m)
            else:
                new_nodes.append(node)
        if len(modified_nodes) > 0:
            if verbose > 1:
                print(
                    "[onnx_inline_function-graph] %s -2 graph=%d "
                    "len(modified_nodes)=%d"
                    % ("  " * level, id(graph), len(modified_nodes))
                )

            nodes = new_nodes
            graph = make_graph(
                nodes,
                graph.name,
                inputs,
                outputs,
                init,
                doc_string=graph.doc_string,
                sparse_initializer=list(graph.sparse_initializer),
            )

    if verbose > 1:
        print(
            "[onnx_inline_function-graph] %s <visit graph=%d end "
            "changed=%r len(modified_nodes)=%d"
            % ("  " * level, id(graph0), id(graph0) != id(graph), len(modified_nodes))
        )

    return graph, modified_nodes


def _onnx_inline_function_node(
    node: NodeProto,
    protos: Dict[Tuple[str, str], FunctionProto],
    existing_names: Set[str],
    verbose: int,
    level: int,
) -> Tuple[List[NodeProto], List[Union[ValueInfoProto, NodeProto]]]:
    """
    Inline a node.

    The function does not rename the inputs or the outputs of the node.
    If ``(node.domain, node.op_type)`` is a key of *protos*, the node is
    replaced by the body of the function: one Identity node per function
    input, the nodes of the function body with new names, then one
    Identity node per function output. Otherwise the node is returned as is.
    An omitted optional input (empty name or missing trailing input)
    stays omitted inside the function body; an omitted optional output
    of the node is not produced.

    :param node: node to inline
    :param protos: known functions, ``{(domain, name): FunctionProto}``
    :param existing_names: names which cannot be used, every new name
        is added to this set
    :param verbose: verbosity level
    :param level: level of subgraphs
    :return: new nodes, modified nodes
    """
    key = node.domain, node.op_type
    if key not in protos:
        return [node], []

    proto = protos[key]
    assert isinstance(
        proto, FunctionProto
    ), "Prototype for key=%r must be a Function Proto, not %r." % (key, type(proto))
    modified_nodes = [node]
    new_nodes = []
    mapping = _inline_mapping(verbose, level)
    prefix = "_inl"

    node_inputs = list(node.input)
    for index, to in enumerate(proto.input):
        fr = node_inputs[index] if index < len(node_inputs) else ""
        if fr == "":
            # Omitted optional input: every use inside the body is omitted too.
            mapping[to] = ""
            continue
        n = make_node("Identity", [fr], [_get_new_name(prefix, to, existing_names)])
        if verbose > 2:
            print(
                "[onnx_inline_function-ninpu] %s add node %r(%r): %r -> %r"
                % ("  " * level, n.op_type, n.name, n.input, n.output)
            )
        mapping[to] = n.output[0]
        if to != n.output[0]:
            mapping[n.output[0]] = n.output[0]
        new_nodes.append(n)

    # attributes given by the node calling the function
    attributes = {att.name: att for att in node.attribute}
    # default values declared by the function (only exists in recent onnx)
    defaults = {att.name: att for att in getattr(proto, "attribute_proto", [])}

    for nn in proto.node:
        # an empty name is an omitted optional input or output, it stays empty
        new_input = [mapping[i] if i else "" for i in nn.input]
        new_output = [
            _get_new_name(prefix, o, existing_names) if o else "" for o in nn.output
        ]
        mapping.update({o: oo for o, oo in zip(nn.output, new_output) if o})
        mapping.update({oo: oo for oo in new_output if oo})
        new_node = make_node(
            nn.op_type,
            new_input,
            new_output,
            domain=nn.domain,
            name=_get_new_name(prefix, nn.name, existing_names),
        )
        if verbose > 3:
            print(
                "[onnx_inline_function-nnode]   %s rep node %r(%r): %r -> %r"
                % ("  " * level, nn.op_type, nn.name, nn.input, nn.output)
            )
        if verbose > 2:
            print(
                "[onnx_inline_function-nnode] %s add node %r(%r): %r -> %r"
                % (
                    "  " * level,
                    new_node.op_type,
                    new_node.name,
                    new_node.input,
                    new_node.output,
                )
            )
        for att in nn.attribute:
            if hasattr(att, "ref_attr_name") and att.ref_attr_name:
                # linked attribute, the value comes from the calling node
                # or from the default value declared by the function
                ref_name = att.ref_attr_name
                linked = attributes.get(ref_name, defaults.get(ref_name))
                assert linked is not None, (
                    f"A linked attribute {ref_name!r} "
                    f"cannot be found in {list(sorted(attributes))} "
                    f"for operator type {nn.op_type!r} and attribute {att.name!r}."
                )
                new_att = AttributeProto()
                new_att.ParseFromString(linked.SerializeToString())
                new_att.name = att.name
                att = new_att
                if verbose > 3:
                    print(
                        "[onnx_inline_function-funct]   %s fct=%r att %r linked to %r"
                        % ("  " * level, key, att.name, ref_name)
                    )
            elif (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                if verbose > 1:
                    print(
                        "[onnx_inline_function-funct] %s fct=%r graph=%d node=%d"
                        % ("  " * level, key, id(att.g), id(new_node))
                    )

                g, m = _onnx_inline_function_graph(
                    att.g,
                    protos,
                    existing_names=existing_names,
                    verbose=verbose,
                    mapping=mapping,
                    rename=True,
                    level=level + 1,
                )
                if len(m) > 0:
                    att = make_attribute(att.name, g)
                else:
                    att = make_attribute(att.name, att.g)
            new_node.attribute.append(att)
        new_nodes.append(new_node)

    for fr, to in zip(proto.output, node.output):
        if to == "":
            # The caller does not use this optional output.
            continue
        n = make_node("Identity", [mapping[fr]], [to])
        if verbose > 2:
            print(
                "[onnx_inline_function-noutt] %s add node %r(%r): %r -> %r"
                % ("  " * level, n.op_type, n.name, n.input, n.output)
            )
        new_nodes.append(n)
    return new_nodes, modified_nodes


def onnx_inline_function(
    obj: Union[FunctionProto, GraphProto, ModelProto],
    protos: Optional[
        Union[Dict[Tuple[str, str], FunctionProto], List[Union[str, FunctionProto]]]
    ] = None,
    existing_names: Optional[Set[str]] = None,
    verbose: int = 0,
) -> Tuple[Union[FunctionProto, GraphProto, ModelProto], List[NodeProto]]:
    """
    Inlines functions in an ONNX graph.

    :param obj: onnx graph, FunctionProto, GraphProto, ModelProto
    :param protos: if None, the function assumes *obj* is of type
        ModelProto and the goal is to inline every function.
        If *obj* is a ModelProto and *protos* a list of strings, the function
        only inlines the local functions whose name is in that list
        (FunctionProto elements of the list are used as given).
        If *obj* is not a ModelProto, a list must contain FunctionProto.
        If *protos* is a dictionary `{ (domain, type): FunctionProto }`,
        the function replaces every node `(domain, type)` by the code
        given in this dictionary
    :param existing_names: no new name will be taken in that set,
        the names of *obj* are always excluded as well
    :param verbose: verbosity
    :return: modified object, list of modified nodes
    :raises RuntimeError: if an inlined local function still appears
        in the graph of a model or if initializers were lost
    """
    if isinstance(obj, ModelProto):
        if verbose > 0:
            print("[onnx_inline_function] type=%r graph=%d" % (type(obj), id(obj)))
        if protos is None:
            # Every local function is inlined.
            protos = [f.name for f in obj.functions]
        if isinstance(protos, list):
            ex_names = set(enumerate_onnx_names(obj))
            if existing_names is not None:
                ex_names |= existing_names
            # Strings select local functions by name.
            selected = {p for p in protos if isinstance(p, str)}
            new_protos = {
                (f.domain, f.name): f for f in obj.functions if f.name in selected
            }
            new_protos.update(
                {(p.domain, p.name): p for p in protos if isinstance(p, FunctionProto)}
            )
            return onnx_inline_function(
                obj, new_protos, existing_names=ex_names, verbose=verbose
            )

    if isinstance(protos, list):
        protos = {(f.domain, f.name): f for f in protos}
    assert isinstance(
        protos, dict
    ), "obj is of type %r and protos must be a dictionary not %r." % (
        type(obj),
        type(protos),
    )

    if isinstance(obj, ModelProto):
        return _onnx_inline_function_model(obj, protos, existing_names, verbose)

    # FunctionProto, GraphProto
    if existing_names is None:
        existing_names = set(enumerate_onnx_names(obj))
    else:
        # The names of the graph must never be reused either.
        existing_names.update(enumerate_onnx_names(obj))
    if verbose > 0:
        print("[onnx_inline_function] type=%r graph=%d begin" % (type(obj), id(obj)))
        distri = Counter((n.domain, n.op_type) for n in enumerate_onnx_nodes(obj))

    new_nodes = list(obj.node)
    modified_nodes = []
    n_iter = 0
    # One iteration inlines one level of function calls. A function may call
    # another function or hold subgraphs, the loop stops as soon as nothing
    # changes, this is only an upper bound.
    max_iter = (
        onnx_subgraphs_level(obj)
        + sum(
            onnx_subgraphs_level(p)
            for p in protos.values()
            if isinstance(p, FunctionProto)
        )
        + 1
    )
    modified = 1
    while modified > 0 and n_iter < max_iter:
        if verbose > 0:
            print(f"[onnx_inline_function] start iteration {n_iter!r}")

        # local context
        mapping = _inline_mapping(verbose, level=0)
        if isinstance(obj, GraphProto):
            mapping.update({i.name: i.name for i in obj.initializer})
            # SparseTensorProto has no name field, it is held by its values.
            mapping.update(
                {i.values.name: i.values.name for i in obj.sparse_initializer}
            )
            for i in obj.input:
                if i.name not in mapping:
                    mapping[i.name] = i.name
        elif isinstance(obj, FunctionProto):
            mapping.update({i: i for i in obj.input})
        else:
            raise TypeError(f"Unexpected type for obj: {type(obj)!r}.")

        # loop on nodes
        old_nodes = new_nodes
        modified = 0
        new_nodes = []
        for node in old_nodes:
            nnodes, m = _onnx_inline_function_node(
                node, protos, existing_names, verbose, level=0
            )
            # empty names are omitted optional outputs, several nodes may have one
            mapping.update({o: o for o in node.output if o})

            if len(m) > 0:
                if verbose > 0:
                    print(
                        "[onnx_inline_function] replaced node %r (%r) "
                        "with %d nodes (id=%r) -- %r -> %r (iter=%r)"
                        % (
                            node.name,
                            node.op_type,
                            len(nnodes),
                            id(node),
                            node.input,
                            node.output,
                            n_iter,
                        )
                    )
                modified += len(m)
                new_nodes.extend(nnodes)
                modified_nodes.extend(m)
                continue

            # The node is not a function call, its subgraphs may hold some.
            has_graph = False
            new_attributes = []
            for att in node.attribute:
                if (
                    att.type == AttributeProto.GRAPH
                    and hasattr(att, "g")
                    and att.g is not None
                ):
                    g, m = _onnx_inline_function_graph(
                        att.g,
                        protos,
                        verbose=verbose,
                        existing_names=existing_names,
                        mapping=mapping,
                        rename=False,
                        level=1,
                    )
                    if len(m) > 0:
                        modified_nodes.extend(m)
                        modified_nodes.append(node)
                        modified += 1 + len(m)
                        has_graph = True
                        att = make_attribute(att.name, g)
                new_attributes.append(att)
            if has_graph:
                new_node = make_node(
                    node.op_type,
                    node.input,
                    node.output,
                    domain=node.domain,
                    name=node.name,
                )
                new_node.attribute.extend(new_attributes)
                new_nodes.append(new_node)
            else:
                # we still need to check that this subgraph does
                # not include a function
                new_nodes.append(node)

        n_iter += 1
        if verbose > 0:
            total_node = len(list(enumerate_onnx_nodes(new_nodes)))
            print(
                "[onnx_inline_function] n_iter=%r/%r nodes=%r modified=%r "
                "n_nodes=%d total=%d"
                % (
                    n_iter,
                    max_iter,
                    len(obj.node),
                    modified,
                    len(new_nodes),
                    total_node,
                )
            )

    if verbose > 0:
        print(
            "[onnx_inline_function] type=%r graph=%d end with %d "
            "modified nodes" % (type(obj), id(obj), len(modified_nodes))
        )
        distri2 = Counter(
            (n.domain, n.op_type) for n in enumerate_onnx_nodes(new_nodes)
        )
        if distri != distri2:
            print("[onnx_inline_function] BEFORE")
            for k, v in sorted(distri.items()):
                print("[onnx_inline_function] %d -- %s" % (v, k))
            print("[onnx_inline_function] AFTER")
            for k, v in sorted(distri2.items()):
                print("[onnx_inline_function] %d -- %s" % (v, k))

    if isinstance(obj, FunctionProto):
        return (
            make_function(
                domain=obj.domain,
                fname=obj.name,
                inputs=obj.input,
                outputs=obj.output,
                nodes=new_nodes,
                opset_imports=[
                    make_operatorsetid(op.domain, op.version)
                    for op in obj.opset_import
                ],
                doc_string=obj.doc_string,
                attributes=obj.attribute,
            ),
            modified_nodes,
        )
    if isinstance(obj, GraphProto):
        return (
            make_graph(
                new_nodes,
                obj.name,
                list(obj.input),
                list(obj.output),
                list(obj.initializer),
                doc_string=obj.doc_string,
                sparse_initializer=list(obj.sparse_initializer),
            ),
            modified_nodes,
        )
    raise TypeError(f"Unexpected type for obj {type(obj)!r}.")


def _onnx_inline_function_model(
    obj: ModelProto,
    protos: Dict[Tuple[str, str], FunctionProto],
    existing_names: Optional[Set[str]],
    verbose: int,
) -> Tuple[ModelProto, List[NodeProto]]:
    """
    Inlines the functions of *protos* in the main graph of a model and
    rebuilds the model: inlined local functions are removed unless a kept
    local function still calls them, opset domains are completed.

    :param obj: model
    :param protos: functions to inline, ``{(domain, name): FunctionProto}``
    :param existing_names: names which cannot be used (the model names
        are always added)
    :param verbose: verbosity
    :return: new model, list of modified nodes
    """
    ex_names = set(enumerate_onnx_names(obj))
    if existing_names is not None:
        ex_names |= existing_names
    new_graph, m = onnx_inline_function(
        obj.graph, protos, existing_names=ex_names, verbose=verbose
    )
    if len(new_graph.initializer) != len(obj.graph.initializer):
        raise RuntimeError(
            "Mismatched number of initializers %d != %d."
            % (len(new_graph.initializer), len(obj.graph.initializer))
        )
    if len(new_graph.sparse_initializer) != len(obj.graph.sparse_initializer):
        raise RuntimeError(
            "Mismatched number of sparse initializers %d != %d."
            % (len(new_graph.sparse_initializer), len(obj.graph.sparse_initializer))
        )

    distri = Counter((n.domain, n.op_type) for n in enumerate_onnx_nodes(new_graph))
    opsets = {op.domain: op.version for op in obj.opset_import}
    kept = set()
    for f in obj.functions:
        key = f.domain, f.name
        if key not in protos:
            kept.add(key)
        elif key in distri:
            raise RuntimeError(
                "Function %r still appears in the graph, "
                "distibution=%s." % (key, pprint.pformat(distri))
            )
        if f.domain not in opsets:
            opsets[f.domain] = 1

    # An inlined local function may still be called by a kept one,
    # it must be kept as well (the closure is computed until it is stable).
    local_keys = {(f.domain, f.name) for f in obj.functions}
    changed = True
    while changed:
        changed = False
        for f in obj.functions:
            if (f.domain, f.name) not in kept:
                continue
            for n in enumerate_onnx_nodes(f):
                called = n.domain, n.op_type
                if called in local_keys and called not in kept:
                    kept.add(called)
                    changed = True
    new_functions = [f for f in obj.functions if (f.domain, f.name) in kept]

    # Inlined nodes may use domains only declared by the functions.
    for f in protos.values():
        for op in getattr(f, "opset_import", []):
            if op.domain not in opsets:
                opsets[op.domain] = op.version

    return (
        make_model(
            new_graph,
            functions=new_functions,
            opset_imports=[make_operatorsetid(k, v) for k, v in opsets.items()],
            producer_name=obj.producer_name,
            producer_version=obj.producer_version,
            ir_version=obj.ir_version,
            doc_string=obj.doc_string,
            domain=obj.domain,
            model_version=obj.model_version,
        ),
        m,
    )