Source code for experimental_experiment.xoptim.repeated_optim

import hashlib
from collections import Counter
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from .patterns_api import make_pattern_from_onnx, OnnxEasyPatternOptimization


def node_type_frequency(
    onx: Union[Sequence[onnx.NodeProto], onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto],
    min_freq: int = 2,
) -> Tuple[Dict[Tuple[str, str], int], Dict[int, int], int, List[Tuple[str, str]]]:
    """
    Computes the frequency of every node type in a list.

    :param onx: any object containing a sequence of NodeProto
    :param min_freq: do not consider any frequency below that threshold
    :return: 4 results: the frequencies of the node types,
        the frequencies of those frequencies,
        the most frequent frequency (an estimation of the number of layers),
        and all types having exactly that frequency

    .. note::
        This function assumes at least one type of node is present
        only once in every layer.
    """
    if isinstance(onx, onnx.ModelProto):
        return node_type_frequency(onx.graph, min_freq=min_freq)
    h = Counter((node.domain, node.op_type) for node in onx.node)
    freq = {k: v for k, v in h.items() if v >= min_freq}
    freq_freq = Counter(freq.values())
    freqs = dict(freq_freq)
    for k, v in freq_freq.items():
        # a frequency k also counts for its divisors already observed
        for i in range(2, k):
            if k % i == 0 and i in freq_freq:
                freqs[i] += k // i * v
    ret = max((v, k) for k, v in freqs.items())
    types = [k for k, v in freq.items() if v == ret[1]]
    return freq, freqs, ret[1], types
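
# The example below is an illustrative sketch added for documentation; it is
# not part of the original module. It builds a small GraphProto repeating the
# same MatMul -> Add -> Relu block three times and shows what
# ``node_type_frequency`` returns on it. Names such as ``W0`` or ``mm0`` are
# arbitrary.
def _example_node_type_frequency():
    nodes, name = [], "X"
    for layer in range(3):
        nodes.extend(
            [
                oh.make_node("MatMul", [name, f"W{layer}"], [f"mm{layer}"]),
                oh.make_node("Add", [f"mm{layer}", f"B{layer}"], [f"add{layer}"]),
                oh.make_node("Relu", [f"add{layer}"], [f"relu{layer}"]),
            ]
        )
        name = f"relu{layer}"
    graph = oh.make_graph(
        nodes,
        "example",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, 4])],
        [oh.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [None, 4])],
    )
    # The weights W*, B* are left undeclared: only node types matter here.
    freq, freq_of_freq, n_layers, candidates = node_type_frequency(graph)
    # freq == {("", "MatMul"): 3, ("", "Add"): 3, ("", "Relu"): 3},
    # n_layers == 3, candidates lists the types appearing exactly 3 times.
    return freq, freq_of_freq, n_layers, candidates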
def _serialize_attribute(attribute: Sequence[onnx.AttributeProto]) -> bytes:
    return b"/".join(a.SerializeToString() for a in attribute)


class _GraphPattern:
    def __init__(self, first_node: int):
        self.cursor = first_node
        self.first_node = first_node
        self.subgraph = set()

    def add_cursor(self):
        assert self.cursor >= 0, f"Cannot add a negative cursor ({self.cursor})"
        assert (
            self.cursor not in self.subgraph
        ), f"Cursor {self.cursor} already added in {self.subgraph}"
        self.subgraph.add(self.cursor)


class _GraphIterator:
    def __init__(self, graph: "_GraphPatterns", node_index: int):
        self.graph = graph
        self.node_index = node_index
        self.io_index = None
        self.io_kind = None
        self.io_name = None
        self.o_suc = None
        self.o_suc_index = None

    def __str__(self) -> str:
        if self.node_index is None:
            return "it()"
        indices = [
            self.node_index,
            self.io_index,
            "N" if self.io_kind is None else ("I" if self.io_kind else "O"),
            "." if self.o_suc is None else self.o_suc,
            self.io_name,
        ]
        s = ", ".join(map(str, indices))
        return f"it({s})"

    def next(self):
        # assumes node.output is never empty
        node = self.graph.nodes[self.node_index]
        if self.io_name is None:
            # first call: start with the first input (or the first output if no input)
            self.io_index = 0
            self.o_suc = 0
            self.io_kind = bool(node.input)
            self.io_name = node.input[0] if self.io_kind else node.output[0]
            self.o_suc_index = (
                None
                if self.io_kind or not self.graph.successors[self.io_name]
                else self.graph.successors[self.io_name][self.o_suc]
            )
        elif self.io_kind:
            # iterate over the inputs, then switch to the outputs
            self.io_index += 1
            self.o_suc = 0
            if self.io_index >= len(node.input):
                self.io_kind = False
                self.io_index = 0
            self.io_name = (
                node.input[self.io_index] if self.io_kind else node.output[self.io_index]
            )
            self.o_suc_index = (
                None
                if not self.graph.successors[self.io_name]
                else self.graph.successors[self.io_name][self.o_suc]
            )
        else:
            # iterate over the outputs and their successors
            self.io_name = node.output[self.io_index]
            if self.io_name not in self.graph.successors:
                self.io_name = None
                return False
            self.o_suc += 1
            if self.o_suc < len(self.graph.successors[self.io_name]):
                self.o_suc_index = self.graph.successors[self.io_name][self.o_suc]
                return True
            self.o_suc = 0
            self.io_index += 1
            if self.io_index >= len(node.output):
                self.io_name = None
                self.o_suc_index = None
                return False
            self.io_name = node.output[self.io_index]
            self.o_suc_index = (
                None
                if not self.graph.successors[self.io_name]
                else self.graph.successors[self.io_name][self.o_suc]
            )
        return True

    def get_name(self, node_index: int) -> str:
        node = self.graph.nodes[node_index]
        if self.io_kind is None:
            return None
        if self.io_kind:
            name = node.input[self.io_index]
        else:
            name = node.output[self.io_index]
        assert (
            node_index != self.node_index or name == self.io_name
        ), f"Inconsistency with node_index={node_index}, name={name!r}, self={self!r}"
        return name

    def get_node_index(self, node_index: int) -> int:
        node = self.graph.nodes[node_index]
        if self.io_kind is None:
            return None
        if self.io_kind:
            name = node.input[self.io_index]
            index = self.graph.predecessor.get(name, -1)
        else:
            name = node.output[self.io_index]
            suc = self.graph.successors.get(name, [])
            if not suc:
                return -1
            # It is tricky here because the order of the successors
            # is not necessarily the same.
            if self.o_suc == 0 and len(suc) == 1:
                # Only one possible.
                index = suc[self.o_suc]
            else:
                assert self.o_suc_index is not None, (
                    f"Unable to guess the forward node, node_index={node_index}, "
                    f"self={self}, mapped={self.graph.mapped}"
                )
                expected_sig = self.graph.signatures[self.o_suc_index]
                sigs = {self.graph.signatures[s]: s for s in suc}
                assert len(sigs) == len(suc), (
                    f"Unable to distinguish between successors signatures: {sigs}, "
                    f"node_index={node_index}, type is "
                    f"{self.graph.nodes[node_index].op_type!r} "
                    f"name is {self.graph.nodes[node_index].name!r}, self={self}, "
                    f"len(suc)={len(suc)}"
                )
                if expected_sig not in sigs:
                    # Cannot find the expected successor.
                    return -1
                index = sigs[expected_sig]
            return index
        assert node_index != self.node_index or name == self.io_name, (
            f"Inconsistency with node_index={node_index}, "
            f"self.io_index={self.io_index!r}, name={name!r}, self={self!r}"
        )
        return index


class _GraphPredecessorSuccessors:
    def __init__(
        self,
        nodes: List[onnx.NodeProto],
        initializer: Optional[Sequence[onnx.TensorProto]] = None,
        input_names: Optional[List[str]] = None,
    ):
        self.nodes = nodes
        self.initializer = {init.name: init for init in initializer} if initializer else {}
        self.input_names = input_names or []
        self.build_edges()
        self.build_all_predecessors()
        self.build_signatures()

    def input_names_involved(self, node: onnx.NodeProto) -> Union[str, List[str]]:
        set_inputs = set(self.input_names)
        inputs = set()
        allp = self.all_precessors[node.output[0]]
        for i in allp:
            n = self.nodes[i]
            inputs |= set(n.input) & set_inputs
        if len(inputs) == len(set_inputs):
            return "ALL"
        return sorted(inputs)

    def make_signature(self, node: onnx.NodeProto) -> str:
        hash = (
            f"H{hashlib.sha256(_serialize_attribute(node.attribute)).hexdigest()[:20]}"
            if node.attribute
            else ""
        )
        sigi = []
        for i in node.input:
            if i in self.initializer:
                # small constants contribute their values to the signature
                cst = self.initializer[i]
                shape = tuple(cst.dims)
                if len(shape) <= 1:
                    size = np.prod(shape)
                    if size < 1024:
                        t = onh.to_array(cst).ravel()
                        if t.size < 16:
                            c = ",".join(str(x) for x in t.ravel())
                        else:
                            c = ",".join(str(x) for x in t.ravel()[:16])
                        sigi.append(c)
                    else:
                        sigi.append("CC")
                else:
                    sigi.append("C")
            else:
                p = self.predecessor.get(i, -1)
                if p >= 0:
                    n = self.nodes[p]
                    if len(n.output) > 1:
                        sigi.append(f"{n.op_type}.{list(n.output).index(i)}")
                    else:
                        sigi.append(n.op_type)
                else:
                    sigi.append("")
        sigo = []
        if len(node.output) == 1 and len(self.successors.get(node.output[0], [])) == 1:
            suc = self.successors[node.output[0]]
            n = self.nodes[suc[0]]
            sigo = [f"{n.op_type}.{list(n.input).index(node.output[0])}"]
        sig = (
            f"{node.domain}/{node.op_type}/{len(node.input)}-{len(node.output)}"
            f"{hash}<<{'/'.join(sigi)}>>{'/'.join(sigo)}"
        )
        iv = self.input_names_involved(node)
        if iv != "ALL":
            sig += f"II{'/'.join(iv)}"
        return sig

    def build_edges(self):
        self.successors: Dict[str, List[int]] = {}
        self.predecessor: Dict[str, int] = {}
        self.signatures: Dict[int, str] = {}
        self.result_names = set()
        for node_index, node in enumerate(self.nodes):
            self.result_names |= set(node.input) | set(node.output)
            for i in node.output:
                self.predecessor[i] = node_index
            for i in node.input:
                if i not in self.successors:
                    self.successors[i] = []
                self.successors[i].append(node_index)

    def build_signatures(self):
        for node_index, node in enumerate(self.nodes):
            sig = self.make_signature(node)
            self.signatures[node_index] = sig

    def build_all_predecessors(self):
        self.all_precessors = {}
        for k, v in self.predecessor.items():
            self.all_precessors[k] = {v}
        changes = 1
        it = 0
        while changes and it < len(self.nodes) // 2 + 1:
            changes = 0
            for _k, v in self.all_precessors.items():
                addition = set()
                for index in v:
                    n = self.nodes[index]
                    for i in n.input:
                        if i in self.all_precessors:
                            addition |= self.all_precessors[i]
                before = len(v)
                v |= addition
                if len(v) > before:
                    changes += 1
            it += 1
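
# Illustrative sketch added for documentation, not part of the original module.
# It shows the helper structures built by ``_GraphPredecessorSuccessors`` on a
# two-node graph; the names ``X``, ``scale``, ``bias`` are arbitrary.
def _example_graph_predecessor_successors():
    nodes = [
        oh.make_node("Mul", ["X", "scale"], ["xs"]),
        oh.make_node("Add", ["xs", "bias"], ["Y"]),
    ]
    inits = [
        onh.from_array(np.array([2.0], dtype=np.float32), name="scale"),
        onh.from_array(np.array([1.0], dtype=np.float32), name="bias"),
    ]
    gr = _GraphPredecessorSuccessors(nodes, initializer=inits, input_names=["X"])
    # gr.predecessor maps a result name to the index of the node producing it,
    # gr.successors maps it to the indices of the nodes consuming it, and
    # gr.signatures gives, per node index, a string describing the node, its
    # attributes and the constants feeding it (used to match nodes across layers).
    return gr.predecessor, gr.successors, gr.signatures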
def make_function_from_nodes(
    nodes: List[onnx.NodeProto], name: str = "function", domain: str = "repeated"
) -> onnx.FunctionProto:
    """
    Creates a function from a list of nodes.
    The function inputs are the names not produced by any of the nodes,
    the function outputs are the names not consumed by any of them.
    Opset versions are all set to one.

    :param nodes: list of nodes
    :param name: function name
    :param domain: domain name
    :return: function proto
    """
    gr = _GraphPredecessorSuccessors(nodes)
    domains = sorted(set(n.domain for n in nodes))
    inputs = sorted(
        k for k in gr.result_names if k not in gr.predecessor or gr.predecessor[k] is None
    )
    outputs = sorted(
        k for k in gr.result_names if k not in gr.successors or not gr.successors[k]
    )
    return oh.make_function(
        name,
        domain,
        inputs,
        outputs,
        nodes,
        opset_imports=[oh.make_opsetid(n, 1) for n in domains],
    )
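
# Illustrative sketch added for documentation, not part of the original module.
# It wraps two chained nodes into a FunctionProto; ``X``, ``B``, ``Y`` are
# arbitrary names chosen for the example.
def _example_make_function_from_nodes():
    # "X" and "B" are not produced by any node, so they become the function
    # inputs; "Y" is never consumed, so it becomes the function output.
    nodes = [
        oh.make_node("Neg", ["X"], ["neg"]),
        oh.make_node("Add", ["neg", "B"], ["Y"]),
    ]
    proto = make_function_from_nodes(nodes, name="NegAdd", domain="repeated")
    # list(proto.input) == ["B", "X"] (sorted), list(proto.output) == ["Y"]
    return proto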
class _GraphPatterns(_GraphPredecessorSuccessors):
    def __init__(
        self,
        nodes: List[onnx.NodeProto],
        cursor: Sequence[int],
        initializer: Optional[Sequence[onnx.TensorProto]] = None,
        input_names: Optional[List[str]] = None,
    ):
        super().__init__(nodes, initializer, input_names)
        self.change_cursor(cursor)

    def change_cursor(self, cursor):
        self.pats = [_GraphPattern(c) for c in cursor]
        self.current: List[_GraphIterator] = []
        self.processed_indices = set()
        self.mapped: Dict[int, List[int]] = {}

    def validate_cursor(self, verbose: int = 0):
        # op_types
        if any(p.cursor < 0 for p in self.pats):
            if verbose > 2:
                print("[_GraphPatterns.validate_cursor] INVALID (-1)")
            return False
        # already processed
        if any(p.cursor in self.processed_indices for p in self.pats):
            if verbose > 2:
                print("[_GraphPatterns.validate_cursor] INVALID (processed)")
            return False
        nodes = [self.nodes[p.cursor] for p in self.pats]
        rec = set((n.op_type, len(n.input), len(n.output), len(n.attribute)) for n in nodes)
        if len(rec) != 1:
            if verbose > 2:
                print("[_GraphPatterns.validate_cursor] INVALID (not unique type)")
            return False
        n_atts = rec.pop()[-1]
        if n_atts == 0:
            if verbose > 2:
                print("[_GraphPatterns.validate_cursor] VALID (1)")
            return True
        # Needs to check attributes.
        base = _serialize_attribute(nodes[0].attribute)
        for n in nodes[1:]:
            get = _serialize_attribute(n.attribute)
            if get != base:
                if verbose > 2:
                    print(
                        "[_GraphPatterns.validate_cursor] INVALID (not the same attribute)"
                    )
                return False
        if verbose > 2:
            print("[_GraphPatterns.validate_cursor] VALID (2)")
        return True

    def add_cursor(self):
        bug = set()
        for pi, p in enumerate(self.pats):
            assert p.cursor not in bug, (
                f"Every cursor should be different, pi={pi}, but "
                f"{[p.cursor for p in self.pats]}"
            )
            bug.add(p.cursor)
            p.add_cursor()
            if p.cursor not in self.mapped:
                self.mapped[p.cursor] = set()
        for p in self.pats:
            for pp in self.pats:
                self.mapped[p.cursor].add(pp.cursor)

    def apply_path(self, node_index: int) -> int:
        if not self.current:
            return node_index
        for p in self.current:
            node_index = p.get_node_index(node_index)
        return node_index

    def set_cursor(self):
        bug = set()
        for pi, p in enumerate(self.pats):
            p.cursor = self.apply_path(p.first_node)
            assert p.cursor is not None, (
                f"Wrong cursor for p.first_node={p.first_node} and "
                f"path={'/'.join(map(str, self.current))}, pi={pi}"
            )
            if p.cursor >= 0:
                if p.cursor in bug:
                    # This means one input is shared across multiple patterns.
                    # This cannot be possible.
                    p.cursor = -1
                else:
                    bug.add(p.cursor)

    def next_valid(self):
        i = self.pats[0].cursor
        self.current.append(_GraphIterator(self, i))
        return self.next_not_valid()

    def next_not_valid(self):
        has_next = self.current[-1].next()
        while not has_next:
            self.current.pop()
            if not self.current:
                return False
            has_next = self.current[-1].next()
        self.set_cursor()
        return True

    def add_processed_cursor(self):
        if any(p.cursor == -1 for p in self.pats):
            return
        for p in self.pats:
            if p.cursor != -1:
                self.processed_indices.add(p.cursor)

    def process(
        self, name: str = "RepeatedPattern", verbose: int = 0
    ) -> Optional[Tuple[Union[List[int], List[List[int]]], OnnxEasyPatternOptimization]]:
        """Main function, looks for repeated patterns."""
        valid = self.validate_cursor(verbose=verbose)
        n_iter = 0
        while n_iter < len(self.nodes):
            if verbose > 1:
                node = self.nodes[self.pats[0].cursor]
                print(
                    f"[_GraphPatterns.process] it={n_iter}: {node.op_type!r}: "
                    f"{','.join(str(p.cursor) for p in self.pats)}"
                )
            if valid:
                if verbose:
                    node = self.nodes[self.pats[0].cursor]
                    print(
                        f"[_GraphPatterns.process] add node type "
                        f"{node.op_type}({', '.join(node.input)})"
                    )
                self.add_cursor()
                self.add_processed_cursor()
                is_next = self.next_valid()
            elif self.current:
                # let's try the next one
                self.add_processed_cursor()
                is_next = self.next_not_valid()
            else:
                is_next = False
            if not is_next:
                valid = True
                break
            valid = self.validate_cursor(verbose=verbose)
            n_iter += 1

        if self.pats[0].subgraph:
            indices = sorted(self.pats[0].subgraph)
            nodes = [self.nodes[i] for i in indices]
            proto = make_function_from_nodes(nodes, domain="repeated")
            pattern = make_pattern_from_onnx(
                name,
                proto,
                oh.make_function(
                    "repeated",
                    "pattern",
                    proto.input,
                    proto.output,
                    [
                        oh.make_node(
                            name,
                            proto.input,
                            proto.output,
                            domain="repeated",
                        )
                    ],
                    opset_imports=[oh.make_opsetid("repeated", 1)],
                ),
            )
            return indices, pattern
        return None
def find_largest_repeated_pattern(
    onx: Union[Sequence[onnx.NodeProto], onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto],
    min_freq: int = 2,
    verbose: int = 0,
    all_instances: bool = False,
    name: str = "RepeatedPattern",
) -> Optional[Tuple[Union[List[int], List[List[int]]], OnnxEasyPatternOptimization]]:
    """
    Finds the largest repeated pattern in a graph.

    :param onx: any object containing a sequence of NodeProto
    :param min_freq: do not consider any frequency below that threshold
    :param verbose: verbosity
    :param all_instances: if True, returns all instances
    :param name: class name for the instance of
        :class:`OnnxEasyPatternOptimization <experimental_experiment.xoptim.patterns_api.OnnxEasyPatternOptimization>`
    :return: list of node indices in the pattern, the pattern as a subtype of
        :class:`OnnxEasyPatternOptimization <experimental_experiment.xoptim.patterns_api.OnnxEasyPatternOptimization>`

    The opsets are only correct if the input is a ModelProto.
    """
    if isinstance(onx, onnx.ModelProto):
        res = find_largest_repeated_pattern(
            onx.graph,
            min_freq=min_freq,
            verbose=verbose,
            all_instances=all_instances,
            name=name,
        )
        if res is None:
            return res
        # Let's adjust the opset versions.
        subgraphs, pattern = res
        ds = {d.domain: d.version for d in onx.opset_import}
        for d in pattern._match_model.opset_import:
            d.version = ds[d.domain]
        return subgraphs, pattern

    _freq, _freqs, npats, types = node_type_frequency(onx, min_freq)
    if not types:
        return None
    if verbose:
        print(f"[find_largest_repeated_pattern] number of patterns: {npats}")
        print(f"[find_largest_repeated_pattern] frequencies of frequencies: {_freqs}")
        print(f"[find_largest_repeated_pattern] candidates: {types}")

    # initialization
    keep = None
    all_patterns = None
    nodes = list(onx.node)
    patterns = None
    input_names = (
        onx.input if isinstance(onx, onnx.FunctionProto) else [i.name for i in onx.input]
    )
    for candidate in types:
        if verbose:
            print(f"[find_largest_repeated_pattern] tries: {candidate}")
        cursor = []
        for i, n in enumerate(nodes):
            if (n.domain, n.op_type) == candidate:
                cursor.append(i)
        if patterns is None:
            patterns = _GraphPatterns(
                nodes,
                cursor,
                initializer=onx.initializer if hasattr(onx, "initializer") else None,
                input_names=input_names,
            )
        else:
            patterns.change_cursor(cursor)
        res = patterns.process(verbose=verbose, name=name)
        if res is not None:
            if verbose:
                print(
                    f"[find_largest_repeated_pattern] found a pattern "
                    f"of length {len(res[0])}"
                )
            if keep is None or len(keep[0]) < len(res[0]):
                keep = res
                all_patterns = [sorted(p.subgraph) for p in patterns.pats]
            if len(keep[0]) > 1:
                break
        elif verbose:
            print("[find_largest_repeated_pattern] no pattern found")

    if keep is None:
        return keep
    if all_instances:
        return all_patterns, keep[1]
    return keep
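
# Illustrative sketch added for documentation, not part of the original module.
# It builds a small 3-layer model where every layer is MatMul -> Add -> Relu
# with its own weights, then looks for the largest repeated block. The exact
# outcome depends on the heuristics above, so nothing is asserted here; the
# tensor names (``W0``, ``B0``, ...) are arbitrary.
def _example_find_largest_repeated_pattern():
    nodes, inits, name = [], [], "X"
    for layer in range(3):
        inits.append(
            onh.from_array(np.random.rand(4, 4).astype(np.float32), name=f"W{layer}")
        )
        inits.append(
            onh.from_array(np.random.rand(4).astype(np.float32), name=f"B{layer}")
        )
        nodes.extend(
            [
                oh.make_node("MatMul", [name, f"W{layer}"], [f"mm{layer}"]),
                oh.make_node("Add", [f"mm{layer}", f"B{layer}"], [f"add{layer}"]),
                oh.make_node("Relu", [f"add{layer}"], [f"relu{layer}"]),
            ]
        )
        name = f"relu{layer}"
    model = oh.make_model(
        oh.make_graph(
            nodes,
            "example",
            [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, 4])],
            [oh.make_tensor_value_info(name, onnx.TensorProto.FLOAT, [None, 4])],
            inits,
        ),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    res = find_largest_repeated_pattern(model)
    if res is None:
        return None
    indices, pattern = res
    # indices lists the node indices of one instance of the repeated block,
    # pattern is an OnnxEasyPatternOptimization subclass instance which can be
    # added to the patterns used by the graph optimizer.
    return indices, pattern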