Source code for experimental_experiment.xoptim.graph_builder_optim

import os
import pprint
import time
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
from onnx import AttributeProto, NodeProto, TensorProto
from onnx.shape_inference import infer_shapes
import onnx.helper as oh
import onnx.numpy_helper as onh
from ..xbuilder._onnx_helper import enumerate_subgraphs
from ..xbuilder.type_inference import infer_types
from .patterns_api import MatchResult, PatternOptimization
from .patterns import get_default_patterns


def _count(matches):
    stats = {}
    for n in matches:
        cl = n[0].__class__.__name__
        if cl in stats:
            stats[cl] += 1
        else:
            stats[cl] = 1
    return ", ".join([f"{v}*{k}" for k, v in stats.items()])


[docs] class GraphBuilderPatternOptimization: """ Implements optimization after the conversion is done. The differences between the two models can be display with a command line such as: :: python -m onnx_array_api compare -m1 <model.onnx> -m2 <optimized.onnx> -m nodes -c 80 This class assumes a pattern cannot reuse an existing name. :param builder: GraphBuilder :param patterns: list of patterns to apply :param recursive: goes through subgraphs :param verifies: verifies the model but it takes time :param verbose: verbosity :param dump_applied_patterns: dump applied patterns in a folder, the users can check every pattern dumped as a :epkg:`FunctionProto` :param processor: optimization should be made for this processor or this list of processors (comma separated value) """ def __init__( self, builder: "GraphBuilder", # noqa: F821 patterns: Optional[List[PatternOptimization]] = None, recursive: bool = False, verifies: bool = False, verbose: int = 0, dump_applied_patterns: Optional[str] = None, processor: str = "CPU", ): assert processor in { "CUDA", "CPU", "CPU,CUDA", }, ( f"Unknown processor {processor!r}, " f"if should be string with comma separated value" ) self.builder = builder self.verbose = max(verbose, int(os.environ.get("LOG_PATTERN_OPTIMIZE", "0"))) self.patterns = patterns or get_default_patterns(self.verbose) self.recursive = recursive self.verifies = verifies self.dump_applied_patterns = dump_applied_patterns self.processor = processor self._build() # This assume a name is given once and # no constant can replace an existing one. # _build method should not change it. self._cache_computed_constant = {}
[docs] def has_processor(self, processor: str) -> bool: """ Checks the process is on the list of used processors. """ return processor in self.processor
@property def nodes(self) -> List[NodeProto]: "property" return self.builder.nodes @property def input_names(self) -> List[str]: "property" return self.builder.input_names @property def inputs(self) -> List[Any]: "property" return self.builder.inputs @property def output_names(self) -> List[str]: "property" return self.builder.output_names @property def outputs(self) -> List[Any]: "property" return self.builder.outputs @property def opsets(self): "property" return self.builder.opsets
[docs] def iter_nodes(self) -> Iterator: "iterator" yield from self.builder.nodes
def _build(self): """ Builds successor and predecessor. """ self.positions_ = {} self.nodes_ = {} self.outputs_ = {o.name for o in self.builder.outputs} for i, node in enumerate(self.builder.nodes): key = id(node) self.nodes_[key] = node self.positions_[key] = i self.set_output_names_ = set(self.builder.output_names) self.predecessors_ = {} self.successors_ = {} successors_id = {} self.used_ = set() for k, v in self.nodes_.items(): assert isinstance(v, NodeProto), f"Unexpected type {type(v)} for node {k}" for o in v.output: self.predecessors_[o] = k for i in v.input: if i not in self.successors_: self.successors_[i] = [] successors_id[i] = set() if id(k) not in successors_id[i]: # This test avoids the same successor to appear twice if one node # consumes twice the same node. self.successors_[i].append(k) successors_id[i].add(id(k)) for sub in enumerate_subgraphs(v): g = sub[-1] sub_knowns = set() for n in g.input: sub_knowns.add(n.name) for n in g.initializer: sub_knowns.add(n.name) for n in g.sparse_initializer: sub_knowns.add(n.name) for n in g.node: for i in n.input: if i not in sub_knowns: # an input coming from the parent self.used_.add(i) for i in n.output: sub_knowns.add(i) def get_position(self, node: NodeProto) -> int: return self.positions_[id(node)]
[docs] def is_used_by_subgraph(self, name: str) -> bool: """ Tells if a result is used by a subgraphs. """ return name in self.used_
[docs] def is_output(self, name: str) -> bool: """ Tells if a result is an output. """ return name in self.outputs_
[docs] def is_used(self, name: str) -> bool: """ Tells if a result is used or not, including as an output of the graph. """ if name in self.used_: return True if name in self.successors_: return True if name in self.set_output_names_: return True return False
[docs] def is_used_more_than_once(self, name: str) -> bool: """ Tells if a result is used more than once in the current graph or in a subgraph or if it is an output. """ if self.is_used_by_subgraph(name): return True if self.is_output(name): return True suc = self.successors_[name] return len(suc) > 1
[docs] def is_used_only_by(self, name, *nodes: List[NodeProto]) -> bool: """ Tells if a result is only used by a specific set of nodes. """ next_nodes = self.next_nodes(name) allowed = set(id(n) for n in nodes) return all(id(n) in allowed for n in next_nodes)
[docs] def is_constant(self, name: str) -> bool: """ Tells if a result is a constant. """ return self.builder.is_constant(name)
[docs] def is_constant_scalar( self, name: str, value: Optional[Any] = None, broadcast: bool = False ) -> bool: """ Tells if a constant is a scalar :param name: name :param broadcast: if True, consider 1, [1], [[1]], [[[1]]], ... as scalar as well :param value: value to compare to if specified :return: boolean """ if not self.is_constant(name): return False cst_shape = self.get_constant_shape(name, exc=False) if cst_shape is None: return False if broadcast: if cst_shape != tuple() and set(cst_shape) != {1}: return False elif cst_shape not in (tuple(), (1,)): return False cst = self.get_computed_constant(name) assert hasattr(cst, "numpy") or isinstance( cst, np.ndarray ), f"Unexpected type for constant {name}!r, type is {type(cst)}" shape = cst.shape if broadcast: if shape != tuple() and set(shape) != {1}: return False elif shape not in (tuple(), (1,)): return False if value is None: return True if shape == (1,): return all(cst == value) if shape == tuple(): return cst == value assert broadcast, f"Broadcast should be true at this stage, name={name!r}, cst={cst}" return all(cst.reshape((1,)) == value)
[docs] def get_constant_shape(self, name: str, exc: bool = True) -> Optional[Tuple[int, ...]]: """ Returns the shape of a constant. :param name: name :param exc: raises an exception is not possible :return: shape """ if name in self._cache_computed_constant: return self._cache_computed_constant[name].shape if name in self.builder.initializers_dict: proto = self.builder.initializers_dict[name] elif name in self.builder.constants_: proto = self.builder.constants_[name] elif self.is_constant(name): cst = self.get_computed_constant(name) return cst.shape else: if exc: raise AssertionError( f"Unable to retrieve initializer or constant for {name!r}, " f"is_constant={self.is_constant(name)}" ) return None if isinstance(proto, TensorProto): return tuple(proto.dims) if isinstance(proto, NodeProto) and proto.domain == "": if proto.op_type == "Cast": if self.is_constant(proto.output[0]) and not self.is_constant(proto.input[0]): if exc: raise AssertionError( f"Incompatibilities, output is constant " f"when input is not in node {proto}." ) return None return self.get_constant_shape(proto.input[0], exc=exc) if proto.op_type == "Constant": assert ( len(proto.attribute) == 1 ), f"Unexpected number of attribute for node={proto}" for att in proto.attribute: if att.name == "value": return tuple(att.t.dims) if att.name in {"value_float", "value_int"}: return tuple() raise AssertionError( f"Unable to retrieve shape for name={name!r} (type is NodeProto), " f"node.op_type={proto.op_type!r}, " f"attributes={[att.name for att in proto.attribute]}." ) if self.is_constant(name): cst = self.get_computed_constant(name) return None if cst is None else cst.shape if exc: raise AssertionError( f"Unable to retrieve shape for name={name!r} " f"bash and node {proto.op_type!r}" # f"{self.builder.get_debug_msg()}" ) return None if hasattr(proto, "shape"): return proto.shape if exc: raise AssertionError( f"Unable to retrieve shape for name={name!r} and type {type(proto)}" ) return None
[docs] def get_constant_scalar(self, name: str, broadcast: bool = False) -> Union[int, float]: """ Returns a scalar as a constant. :param name: name :param broadcast: consider [1], [[1]], [[[1]]] as constant as well :return: int or float """ cst = self.get_computed_constant(name) assert hasattr(cst, "numpy") or isinstance( cst, np.ndarray ), f"Unexpected type for constant {name}!r, type is {type(cst)}" assert cst.shape == tuple() or ( (broadcast and set(cst.shape) == {1}) or (not broadcast and cst.shape == (1,)) ), f"Unexpected shape {cst.shape} for constant {name!r}" shape = cst.shape if broadcast: value = cst.reshape((1,))[0] else: value = cst[0] if shape == (1,) else cst if value.dtype in { np.float32, np.float16, np.float64, np.dtype("float32"), np.dtype("float16"), np.dtype("float64"), }: return float(value) if value.dtype in { np.complex64, np.complex128, np.dtype("complex64"), np.dtype("complex128"), }: return complex(value) if value.dtype in { self.builder.torch.float32, self.builder.torch.float16, self.builder.torch.float64, self.builder.torch.bfloat16, }: return float(value) if value.dtype in { self.builder.torch.complex64, self.builder.torch.complex128, }: return complex(value) return int(value)
[docs] def get_computed_constant(self, name: str, statistics: Optional[List[str]] = None) -> Any: """ Returns the value for the constant `name`. """ if name in self._cache_computed_constant: value = self._cache_computed_constant[name] else: value = self.builder.get_constant(name, computed_value=True, exc=False) if value is not None: self._cache_computed_constant[name] = value if statistics is None: return value stats = [] for st in statistics: key = name, st if key in self._cache_computed_constant: stat = self._cache_computed_constant[key] else: if st == "min": stat = value.min() elif st == "max": stat = value.max() else: raise RuntimeError(f"Unknown statistics {st!r} for {name!r}.") self._cache_computed_constant[key] = stat stats.append(stat) return stats
[docs] def get_attribute( self, node: NodeProto, att_name: str, exc: bool = True ) -> Optional[AttributeProto]: """ Returns an attribute for a node. """ return self.builder.get_attribute(node, att_name, exc=exc)
[docs] def get_attributes_with_default(self, node: NodeProto, **default_values) -> Dict[str, Any]: """ Returns integer or float values for attributes. """ return self.builder.get_attributes_with_default(node, **default_values)
[docs] def get_axis(self, node: NodeProto, default_axis: Optional[int] = None) -> int: """ Retrieves the axis for many operators. """ att = self.get_attribute(node, "axis", exc=False) if att is None: assert ( default_axis is not None ), f"Node {node.op_type} has no axis and no default value." return default_axis return att.i
[docs] def get_constant_or_attribute( self, node: NodeProto, attribute: str, input_index: int, cvt: Optional[Callable] = None, ) -> Any: """ Returns an input or the value of an attribute. Some attributes became inputs in more recent opsets. The function checks both. :param node: node :param attribute: attribute name :param input_index: input index :param cvt: if not None, called this conversion function before returning the result :return: value """ found = None for att in node.attribute: if att.name == attribute: found = att assert found is None, ( f"get_constant_or_attribute not implemented " f"for attribute={attribute!r} and node={node}." ) assert input_index < len( node.input ), f"Input {input_index} does not exist in node {node}." val = self.get_computed_constant(node.input[input_index]) if cvt is None: return val try: return cvt(val) except (ValueError, TypeError) as e: raise RuntimeError(f"Unable to convert val={val} with cvt={cvt}") from e
[docs] def has_type(self, name: str) -> bool: """ Tells if a result has a type. """ return self.builder.has_type(name)
[docs] def get_type(self, name: str) -> int: """ Returns the type of a result. """ return self.builder.get_type(name)
[docs] def has_rank(self, name: str) -> int: """ Tells if a result has a rank. """ return self.builder.has_rank(name)
[docs] def get_rank(self, name: str) -> int: """ Returns the rank of a result. """ return self.builder.get_rank(name)
[docs] def has_shape(self, name: str) -> bool: """ Tells if a result has a shape. """ return self.builder.has_shape(name)
[docs] def get_shape(self, name: str) -> int: """ Returns the shape of a result. """ return self.builder.get_shape(name)
[docs] def node_before(self, name: str) -> NodeProto: """ Returns the node producing this output. Returns None if it is an input or an initializer. """ if name not in self.predecessors_: return None predecessor = self.predecessors_[name] return self.nodes_[predecessor]
[docs] def next_node(self, name: str) -> NodeProto: """ Returns the next node if it is unique, otherwise fails. """ res = self.next_nodes(name) assert len(res) == 1, f"Unexpected number of successors {len(res)} for {name!r}" return res[0]
[docs] def next_nodes(self, name: str) -> List[NodeProto]: """ Returns the node consuming the given results. """ if name not in self.successors_: return [] return [self.nodes_[i] for i in self.successors_[name]]
[docs] def try_infer_type(self, name: str, exc: bool = False) -> int: """ Tries to infer the type of a result. :param name: name of the result for which to infer the type :param exc: if True, raises an exception if something goes wrong :return: type """ if self.has_type(name): it = self.get_type(name) if exc and it == 0: raise RuntimeError( f"Unable to guess type for {name!r}, " f"knowns types are {pprint.pformat(self.builder._known_types)}" ) return it assert ( name not in self.builder.initializers_dict ), f"name {name!r} has no type but it is an initializer" assert name not in self.builder.input_names, ( f"name {name!r} has no type but it is an input, " f"known_types={pprint.pformat(self.builder._known_types)}" ) node = self.node_before(name) input_types = [(self.get_type(i) if self.has_type(i) else 0) for i in node.input] output_type = infer_types(node, input_types, name, exc=exc) if output_type > 0: return output_type # second try with more depth input_types = [self.try_infer_type(i, exc=exc) for i in node.input] output_type = infer_types(node, input_types, name, exc=exc) if output_type > 0: return output_type # no luck if exc: raise RuntimeError( f"Unable to guess type for {name!r}, " f"knowns types are {pprint.pformat(self.builder._known_types)}" ) return 0
[docs] def try_infer_shape(self, name: str, exc: bool = False) -> int: """ Tries to infer the type of a result. :param name: name of the result for which to infer the type :param exc: if True, raises an exception if something goes wrong :return: type """ if self.has_shape(name): return self.get_shape(name) if exc: raise RuntimeError( f"Unable to guess shape for {name!r}, " f"knowns shapes are {pprint.pformat(self.builder._known_shapes)}" ) return None
@property def main_opset(self): "Returns the opset for the main domain (assuming it is used)." return self.builder.opsets[""] def make_initializer( self, name: str, value: Any, external: bool = False, msg: str = "", source: Optional[str] = None, ) -> str: if not source: if isinstance(value, np.ndarray): if value.dtype == np.int64 and value.size < 16: source = "GraphBuilderPatternOptimization.make_initializer.1/Shape" elif value.size < 2: source = "GraphBuilderPatternOptimization.make_initializer.1/Small" else: source = "GraphBuilderPatternOptimization.make_initializer.0" new_name = self.builder.make_initializer( name, value, external=external, msg=msg, source=source ) return new_name
[docs] def unique_name(self, prefix: str) -> str: "Returns a unique name." return self.builder.unique_name(prefix)
[docs] def make_node_check_opset( self, op_type: str, inputs: Union[str, List[str]], outputs: Union[int, List[str], str] = 1, domain: str = "", attributes: Optional[List[AttributeProto]] = None, name: Optional[str] = None, **kwargs, ): """ Creates a node without adding it to the graph but adapt for some known operators changing over multiple opets. :param op_type: operator type :param inputs: input names :param outputs: outputs names, if one integer, creates n unique names, if str, creates one unique names, if a list, use the name :param domain: node domain :param attributes: list of attributes :param name: node name :param kwargs: other attributes :return: a node """ assert domain == "", f"The method only supports the main domain not {domain!r}" if op_type in {"Squeeze", "Unsqueeze"}: if self.builder.main_opset < 13: assert len(inputs) == 1, f"axis must be given as an attribute for {op_type!r}" return self.make_node( op_type, inputs, outputs, domain=domain, attributes=attributes, name=name, **kwargs, ) if len(inputs) == 1 and "axes" in kwargs: axes = kwargs["axes"] axes_name = self.make_initializer( "", np.array([axes], dtype=np.int64), source="GraphBuilderPatternOptimization.make_node_check_opset.axes", ) inputs.append(axes_name) del kwargs["axes"] return self.make_node( op_type, inputs, outputs, domain=domain, attributes=attributes, name=name, **kwargs, ) raise RuntimeError(f"Operator {op_type!r} not supported yet.")
[docs] def make_node( self, op_type: str, inputs: Union[str, List[str]], outputs: Union[int, List[str], str] = 1, domain: str = "", attributes: Optional[List[AttributeProto]] = None, name: Optional[str] = None, **kwargs, ) -> NodeProto: """ Creates a node without adding it to the graph. :param op_type: operator type :param inputs: input names :param outputs: outputs names, if one integer, creates n unique names, if str, creates one unique names, if a list, use the name :param domain: node domain :param attributes: list of attributes :param name: node name :param kwargs: other attributes :return: a node """ assert name is not None and not name.startswith("None"), ( f"It is good practice to give every node a name so that is " f"easier to see where this node is created but name={name!r} " f"and op_type={op_type!r}." ) name = self.builder.unique_node_name(name) if isinstance(outputs, int): if outputs == 1: outputs = [self.unique_name(f"{op_type.lower()}-{inputs[0]}")] else: outputs = [ self.unique_name(f"{op_type.lower()}-{inputs[0]}-{i}") for i in range(outputs) ] elif isinstance(outputs, str): outputs = [self.unique_name(outputs)] proto = oh.make_node( op_type, inputs, outputs, domain=domain, name=name, **kwargs, ) if all(self.is_constant(i) for i in inputs): for o in outputs: self.builder.update_node_constant(o, proto) proto.doc_string += ":constant-5:" assert len(outputs) == len(set(outputs)) or "" in outputs, ( f"Repeated outputs for node {op_type}({', '.join(inputs)}) -> " f"{', '.join(outputs)}" ) if attributes: proto.attribute.extend(attributes) return proto
[docs] def apply_match(self, match: MatchResult) -> List[NodeProto]: """ Applies one match. Returns the new nodes. """ idn = [id(n) for n in match.nodes if n is not None] assert all(i in self.nodes_ for i in idn), f"One node in {idn} is not referenced" positions = {id(n): i for i, n in enumerate(self.builder.nodes)} assert all(i in positions for i in idn), f"One node in {idn} is not referenced" removed = [positions[i] for i in idn] position_insert = None if match.insert_at is None else positions[id(match.insert_at)] new_nodes = match.apply(self, *match.nodes) if self.verbose >= 10: print(f"[GraphBuilderPatternOptimization.apply_match] {match}") for node in match.nodes: if node is None: continue print(f" - {node.op_type}: {node.input} -> {node.output}") for node in new_nodes: if node is None: continue print(f" + {node.op_type}: {node.input} -> {node.output}") self.builder.insert_and_remove_nodes(position_insert, new_nodes, removed, debug=match) if self.verbose >= 10: print(f"[GraphBuilderPatternOptimization.apply_match] {match} applied.") if self.dump_applied_patterns: self._save_pattern_as_proto(self.dump_applied_patterns, match, new_nodes) return new_nodes
def _to_cstop(self, init: Any, name: Optional[str] = None) -> NodeProto: if isinstance(init, NodeProto): assert ( name is None or init.output[0] == name ), f"Name mismatch {name!r} != {init.output[0]!r}" return init if isinstance(init, TensorProto): assert ( name is None or init.name == name ), f"Name mismatch {name!r} != {init.name!r}" return oh.make_node("Constant", [], [init.name], value=init) if isinstance(init, np.ndarray): return self._to_cstop(onh.from_array(init, name=name)) import torch if isinstance(init, torch.Tensor): return self._to_cstop(init.detach().cpu().numpy(), name=name) raise AssertionError(f"Unexpected type {type(init)}") def _save_pattern_as_proto( self, folder: str, match: MatchResult, new_nodes: List[NodeProto] ): assert isinstance(folder, str), f"Unexpected type {type(folder)} for folder." if folder and not os.path.exists(folder): os.makedirs(folder) name = f"{match.pattern.__class__.__name__}_0.onnx" fullname = os.path.join(folder, name) n = 0 while os.path.exists(fullname): n += 1 name = f"{match.pattern.__class__.__name__}_{n}.onnx" fullname = os.path.join(folder, name) if self.verbose >= 10: print( f"[GraphBuilderPatternOptimization._save_pattern_as_proto] save {fullname!r}" ) unique_names = set() for node in match.nodes: if node is None: continue unique_names |= set(node.input) unique_names |= set(node.output) new_initializers = {} input_names = set() output_names = set() for node in new_nodes: if node is None: continue for i in node.input: if i in unique_names: input_names.add(i) elif i in self.builder.initializers_dict: new_initializers[i] = self.builder.initializers_dict[i] for o in node.output: if o in unique_names: output_names.add(o) old_initializers = {} for node in match.nodes: if node is None: continue for i in node.input: if i in self.builder.initializers_dict: old_initializers[i] = self.builder.initializers_dict[i] new_init_nodes = [self._to_cstop(v, name=k) for k, v in new_initializers.items()] old_init_nodes = [self._to_cstop(v, name=k) for k, v in old_initializers.items()] fproto = oh.make_function( domain="pattern", fname=match.pattern.__class__.__name__, inputs=list(input_names), outputs=list(output_names), nodes=old_init_nodes + [n for n in match.nodes if n is not None], opset_imports=[oh.make_opsetid(k, v) for k, v in self.builder.opsets.items()], ) fproto_apply = oh.make_function( "pattern", match.pattern.__class__.__name__, list(input_names), list(output_names), new_init_nodes + [n for n in new_nodes if n is not None], opset_imports=[oh.make_opsetid(k, v) for k, v in self.builder.opsets.items()], ) def _sh(n): if self.builder.has_shape(n): return self.builder.get_shape(n) if self.builder.has_rank(n): return [None] * self.builder.get_rank(n) return None inputs = [ oh.make_tensor_value_info(n, self.builder.get_type(n), _sh(n)) for n in fproto.input ] outputs = [ oh.make_tensor_value_info(n, self.builder.get_type(n), _sh(n)) for n in fproto.output ] model = oh.make_model( oh.make_graph(fproto.node, "pattern", inputs, outputs), opset_imports=fproto.opset_import, ) model_apply = oh.make_model( oh.make_graph(fproto_apply.node, "pattern", inputs, outputs), opset_imports=fproto_apply.opset_import, ) if self.builder.ir_version: model.ir_version = self.builder.ir_version model_apply.ir_version = self.builder.ir_version with open(fullname, "wb") as f: f.write(model.SerializeToString()) if self.verbose >= 10: print( f"[GraphBuilderPatternOptimization._save_pattern_as_proto] " f"saved {fullname!r}" ) name = f"{match.pattern.__class__.__name__}_{n}_apply.onnx" fullname = os.path.join(folder, name) with open(fullname, "wb") as f: f.write(model_apply.SerializeToString()) if self.verbose >= 10: print( f"[GraphBuilderPatternOptimization._save_pattern_as_proto] " f"saved {fullname!r}" ) def _chech_graph_verifies(self, node: NodeProto): if ( node.op_type in {"MatMul", "Gemm", "FusedMatMul"} and self.builder.has_shape(node.input[0]) and self.builder.has_shape(node.input[1]) ): sh1 = self.builder.get_shape(node.input[0])[-2:] sh2 = self.builder.get_shape(node.input[1])[-2:] tA = self.builder.get_attribute(node, "transA", exc=False) tB = self.builder.get_attribute(node, "transB", exc=False) tA = 0 if tA is None or tA.i == 0 else 1 tB = 0 if tB is None or tB.i == 0 else 1 if tA: sh1 = (sh1[1], sh1[0]) if tB: sh2 = (sh2[1], sh2[0]) assert type(sh1[-1]) != type(sh2[0]) or sh1[-1] == sh2[0], ( # noqa: E721 f"Node {node.op_type!r}, inputs={node.input}, " f"shape1={self.builder.get_shape(node.input[0])}, " f"shape2={self.builder.get_shape(node.input[1])}, " f"tA={tA}, tB={tB}." ) def _check_graph_verifies_whole(self): onx = self.builder.to_onnx(optimize=False) new_shapes = infer_shapes(onx) for val in new_shapes.graph.value_info: itype = val.type.tensor_type.elem_type shape = tuple( d.dim_param if d.dim_param else d.dim_value for d in val.type.tensor_type.shape.dim ) assert self.builder.has_name(val.name), f"name {val.name!r} is missing" assert ( not self.builder.has_type(val.name) or self.builder.get_type(val.name) == itype ), ( f"Result {val.name!r} has type {itype} but the builder " f"assumes it is {self.builder.get_type(val.name)}" ) assert ( not self.builder.has_shape(val.name) or self.builder.get_shape(val.name) == shape ), ( f"Result {val.name!r} has shape {shape} but the builder " f"assumes it is {self.builder.get_shape(val.name)}" ) # from onnxruntime import InferenceSession # InferenceSession( # onx.SerializeToString(), # providers=["CPUExecutionProvider"], # ) def _check_graph( self, statistics: List[Dict[str, Any]], step: str, iteration: int, code: str, verifies: bool, ): begin = time.perf_counter() assert len(self.builder.nodes) > 0, f"The onnx model is empty (step {step}, no node)" known = set(n.name for n in self.builder.inputs) known |= set(self.builder.initializers_dict) for p, node in enumerate(self.builder.nodes): assert ( node.domain in self.opsets ), f"domain {node.domain!r} is not registered in {self.opsets}" for i in node.input: if i == "": continue if i not in known: after = set() for nn in self.builder.nodes[p:]: after |= set(nn.output) raise AssertionError( f"Unknown input {i!r}, step {step!r} at position {p} " f"in node {node.op_type!r} " f"[{node.name}]: {node.input} -> {node.output}, " f"found after = {i in after}" ) known |= set(node.output) if verifies: self._check_graph_verifies(node) for o in self.builder.outputs: assert o.name in known, f"Unknown output {o.name!r}, step {step!r}" if verifies: self._chech_graph_verifies_whole() statistics.append( dict( pattern=f"check_pattern_{code}{1 if verifies else 0}", time_in=time.perf_counter() - begin, iteration=iteration, ) )
[docs] def do_not_remove(self, node: NodeProto) -> bool: """Tells if a node can be removed.""" return self.builder.do_not_remove(node)
[docs] def optimize( self, max_iter=-1, remove_identity: bool = True, stop_after: int = -1 ) -> List[Dict[str, Any]]: """ Optimizes the based on the given list of patterns. :param max_iter: maximum number of iterations :param remove_identity: remove identity nodes, it is better to keep it True, not doing it might prevent other patterns to find a set of nodes to optimize :param sopt_after: stop after this number of replacements (to debug), -1 not to stop :return: the method returns informations about the applied processes. The algorithm runs multiple iteration until the graph is not evolving or `max_iter` is reached. By default, it is equal to the number of nodes. An iteration is: :: matches = [] builds all successors and predecessors for all patterns P: for all nodes n: r = p.match(n) if r: if no node already scheduled to be rewritten by another match: matches.append(r) for all matches r: apply the match r This algorithm may apply more than one rewriting at each iteration but it guarantees the local structure when applying the rewriting was not altered by another one. """ assert ( not self.recursive ), "GraphBuilderPatternOptimization.optimize does not implement recursivity" continue_optimization = True if max_iter == -1: max_iter = len(self.builder.nodes) priorities = list(sorted(set(p.priority for p in self.patterns))) # noqa: C413 assert priorities, "list of priority is null." if self.verbose > 0: print( f"[GraphBuilderPatternOptimization.optimize] start with " f"{len(self.builder.nodes)} nodes, " f"{len(self.builder.initializers_dict)} initializers, " f"{len(self.patterns)} patterns, priorities={priorities}" ) if self.verbose > 1: for i, (pp, _, pattern) in enumerate( sorted((p.priority, repr(p), p) for p in self.patterns) ): print( f"[GraphBuilderPatternOptimization.optimize] " f"use pattern {i+1:3d}/{len(self.patterns)} - P{pp} - {pattern!r}" ) if self.verbose >= 10: print("--") print(self.builder.pretty_text()) print("--") begin_all = time.perf_counter() statistics = [] self._check_graph(statistics, "-", -1, "0", False) n_applied = 0 last_it = 0 current_priority_index = 0 for it in range(max_iter): if not continue_optimization: break if self.verbose > 0: print( f"[GraphBuilderPatternOptimization.optimize] iteration {it}: " f"{len(self.builder.nodes)} nodes, " f"priority={priorities[current_priority_index]}" ) # detects patterns found = False marked = set() matches = [] durations = {} for pattern in self.patterns: if not continue_optimization: break if pattern.priority > priorities[current_priority_index]: # skipping that pattern continue begin = time.perf_counter() before = len(matches) # loop over the nodes for match in pattern.enumerate_matches(self): # bypass this node if the name contains some specific name fail_match = False for n in match.nodes: if n and self.do_not_remove(n): fail_match = True break if fail_match: continue # checks that a node is not already part of another pattern bypass = False for n in match.nodes: if n is None: continue if id(n) in marked: # a node is already marked for replacements bypass = True break if bypass: if self.verbose >= 9: print( f"[{self.__class__.__name__}.match] OVERLAP " f"match={match} #marked: {len(marked)})" ) continue for n in match.nodes: if n is None: continue marked.add(id(n)) found = True if self.verbose > 2: print(f"[GraphBuilderPatternOptimization.optimize] match={match}") matches.append((pattern, match)) if stop_after > 0 and len(matches) + n_applied >= stop_after: continue_optimization = False if self.verbose > 0: print( f"[GraphBuilderPatternOptimization.optimize] " f"stop after with " f"{len(matches)} as stop_after={stop_after} " f"and n_applied={n_applied}" ) break # matches contains all the matchs d = time.perf_counter() - begin statistics.append( dict( pattern=f"match_{pattern}", iteration=it, instances=len(matches) - before, time_in=d, match_index=len(matches), ) ) durations[pattern.__class__.__name__] = ( durations.get(pattern.__class__.__name__, 0) + d ) if self.verbose > 0 and matches: if durations: rev = max([(v, k) for k, v in durations.items()]) revs = f"{rev[-1]}:{rev[0]:.3f}" if len(matches) == 1: print( f"[GraphBuilderPatternOptimization.optimize] applies " f"{len(matches)} matches, [0]={str(matches[0][-1])} - " f"time={sum(durations.values()):.3f} | max_time={revs}" ) else: print( f"[GraphBuilderPatternOptimization.optimize] applies " f"{len(matches)} matches, {_count(matches)} - " f"time={sum(durations.values()):.3f} | max_time={revs}" ) elif len(matches) == 1: print( f"[GraphBuilderPatternOptimization.optimize] applies " f"{len(matches)} matches, [0]={str(matches[0][-1])}" ) else: print( f"[GraphBuilderPatternOptimization.optimize] applies " f"{len(matches)} matches, {_count(matches)}" ) # applies patterns (they must be disjoined) added_types = set() n_added = 0 n_removed = 0 # loop over patterns for im, (pattern, match) in enumerate(matches): if self.verbose > 3: print( f"[GraphBuilderPatternOptimization.optimize] " f"apply {match.to_string(short=False)}" ) begin = time.perf_counter() added_nodes = self.apply_match(match) added_types |= set(n.op_type for n in added_nodes) if self.verbose > 3: print( f"[GraphBuilderPatternOptimization.optimize] - add " f"{[n.op_type for n in added_nodes]}" ) add = len(added_nodes) added_outputs = set() for n in added_nodes: added_outputs |= set(n.output) rem = len([n for n in match.nodes if n is not None]) removed_outputs = set() for n in match.nodes: if n is None: continue removed_outputs |= set(n.output) full_removed = set(i for i in removed_outputs if i not in added_outputs) for i in full_removed: assert not self.is_output(i), ( f"Output {i!r} must not be removed, added_outputs={added_outputs}," f"removed_outputs={removed_outputs}" ) if self.verbose > 3: print( f"[GraphBuilderPatternOptimization.optimize] done " f"{match}: -{rem} +{add} nodes" ) if full_removed and self.verbose > 4: print( f"[GraphBuilderPatternOptimization.optimize] " f"removed outputs {full_removed}" ) obs = dict( pattern=f"apply_{pattern}", added=add, removed=rem, iteration=it, match_index=im, instances=1, time_in=time.perf_counter() - begin, ) statistics.append(obs) self._check_graph(statistics, str(match), it, "A", self.verifies) n_added += add n_removed += rem n_applied += 1 if self.verbose > 2: print( f"[GraphBuilderPatternOptimization.optimize] done all: " f"-{n_removed} +{n_added} nodes" ) if remove_identity and (it < 3 or "Identity" in added_types): # remove unnecessary identity nodes begin = time.perf_counter() id_removed, id_added = self.builder.remove_identity_nodes() statistics.append( dict( pattern="remove_identity_nodes", iteration=it, added=id_added, removed=id_removed, time_in=time.perf_counter() - begin, ) ) self._check_graph(statistics, "remove_identity", it, "B", self.verifies) # rebuild the graph structure begin = time.perf_counter() self._build() statistics.append( dict( pattern="build_graph_for_pattern", iteration=it, time_in=time.perf_counter() - begin, ) ) # next iteration last_it = it + 1 if not found: # No match, increase the priority. current_priority_index += 1 if current_priority_index >= len(priorities): # There is priority left to explore. continue_optimization = len(matches) > 0 break if self.verbose > 0: print( f"[GraphBuilderPatternOptimization.optimize] increase priority " f"to {priorities[current_priority_index]}" ) if self.verbose > 0: duration = time.perf_counter() - begin_all print( f"[GraphBuilderPatternOptimization.optimize] " f"done after {last_it} iterations with " f"{len(self.builder.nodes)} nodes in {duration:.3f}" ) if self.verbose > 1: msg = self.builder._compile_statistics(statistics) print(msg) return statistics