# Source code for onnx_diagnostic.helpers.graph_helper

import pprint
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
import onnx
import onnx.helper as oh


class GraphRendering:
    """
    Helpers to render a graph as text.

    :param proto: model, graph or function to render
    """

    def __init__(self, proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto]):
        self.proto = proto

    def __repr__(self) -> str:
        "usual"
        return f"{self.__class__.__name__}(<{self.proto.__class__.__name__}>)"

    @classmethod
    def computation_order(
        cls,
        nodes: Sequence[onnx.NodeProto],
        existing: Optional[List[str]] = None,
        start: int = 1,
    ) -> List[int]:
        """
        Returns the soonest a node can be computed,
        every node can assume all nodes with a lower number exists.
        Every node with a higher number must wait for the previous one.

        A node with inputs is placed one row below the deepest of its inputs.
        A node without inputs (a constant) is placed on row ``start``;
        its outputs are numbered with the deepest number known so far but
        never less than ``start``, so every consumer lands on a row strictly
        below the constant. Empty names (missing optional inputs or outputs)
        are ignored.

        :param nodes: list of nodes, every input must be produced by a previous
            node or be listed in ``existing``
        :param existing: names existing before any computation starts,
            they are numbered ``start - 1``
        :param start: lower number
        :return: computation order, one number per node
        :raises AssertionError: if the sequence contains a control flow operator
            or a node consumes an unknown name
        """
        assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
            f"This algorithme is not yet implemented if the sequence contains "
            f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
        )
        number = {e: start - 1 for e in (existing or [])}  # noqa: C420
        results = [start for _ in nodes]
        for i_node, node in enumerate(nodes):
            # An empty name stands for a missing optional input.
            inputs = [i for i in node.input if i]
            assert all(i in number for i in inputs), (
                f"Missing input in node {i_node} type={node.op_type}: "
                f"{[i for i in inputs if i not in number]}"
            )
            if inputs:
                mx = max(number[i] for i in inputs) + 1
                results[i_node] = mx
            else:
                # A constant, drawn on row ``start``. Its outputs must not be
                # numbered below ``start`` or a consumer would share its row.
                mx = max([start, *number.values()])
            for o in node.output:
                if o:
                    number[o] = mx
        return results

    @classmethod
    def graph_positions(
        cls,
        nodes: Sequence[onnx.NodeProto],
        order: List[int],
        existing: Optional[List[str]] = None,
    ) -> List[Tuple[int, int]]:
        """
        Returns positions on a plan for every node in a graph.
        The function minimizes the number of lines crossing each others.
        It goes forward, every line is optimized depending on what is below.
        It could be improved with more iterations.

        :param nodes: list of nodes
        :param order: computation order returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.computation_order`
        :param existing: existing names (not used by the current algorithm)
        :return: list of tuple( row, column), empty if ``order`` is empty
        :raises AssertionError: if a row between the lowest and the highest
            one has no node
        """
        if not order:
            return []
        # Groups node indices by row once, indices stay in increasing order.
        rows: Dict[int, List[int]] = {}
        for i, o in enumerate(order):
            rows.setdefault(o, []).append(i)
        min_row = min(rows)
        n_rows = max(rows) + 1
        names: Dict[str, int] = {}
        positions = [(min_row, i) for i in range(len(order))]
        for row in range(min_row, n_rows):
            indices = rows.get(row, [])
            assert indices, f"indices cannot be empty for row={row}, order={order}"
            # Rightmost column among the inputs of every node of the row,
            # nodes are then sorted by that value (ties keep the node order).
            mx = [
                max((names.get(i, 0) for i in nodes[k].input if i), default=0)
                for k in indices
            ]
            mix = sorted((m, j) for j, m in enumerate(mx))
            for c, (_m, j) in enumerate(mix):
                positions[indices[j]] = (row, c)
                for o in nodes[indices[j]].output:
                    if o:
                        names[o] = c
        return positions

    @classmethod
    def text_positions(
        cls, nodes: Sequence[onnx.NodeProto], positions: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """
        Returns positions for the nodes assuming it is rendered into text.

        :param nodes: list of nodes
        :param positions: positions returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.graph_positions`
        :return: text positions (text row, text column of the node center)
        """
        if not positions:
            return []
        # Rows are spaced by 4 text lines to leave room for the edges,
        # the column is shifted by the row to reduce vertical overlaps.
        new_positions = [(row * 4, col * 2 + row) for row, col in positions]
        column_size = {col: 3 for _, col in new_positions}
        for i, (_row, col) in enumerate(new_positions):
            size = len(nodes[i].op_type) + 5
            column_size[col] = max(column_size[col], size)
            assert column_size[col] < 200, (
                f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, "
                f"nodes[i].op_type={nodes[i].op_type}"
            )
        # Every text column starts where the previous one ends,
        # a node is centered inside its column.
        results: Dict[int, int] = {}
        offset = 0
        for col, size in sorted(column_size.items()):
            results[col] = offset + size // 2
            offset += size
        return [(row, results[col]) for row, col in new_positions]

    @property
    def nodes(self) -> List[onnx.NodeProto]:
        "Returns the list of nodes"
        return (
            self.proto.graph.node
            if isinstance(self.proto, onnx.ModelProto)
            else self.proto.node
        )

    @property
    def start_names(self) -> List[str]:
        "Returns the list of known names, inputs and initializer"
        graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
        input_names = (
            list(graph.input)
            if isinstance(graph, onnx.FunctionProto)
            else [i.name for i in graph.input]
        )
        init_names = (
            []
            if isinstance(graph, onnx.FunctionProto)
            else [
                *[i.name for i in graph.initializer],
                # SparseTensorProto has no name field, the name is held by values.
                *[i.values.name for i in graph.sparse_initializer],
            ]
        )
        return [*input_names, *init_names]

    @property
    def input_names(self) -> List[str]:
        "Returns the list of input names."
        return (
            self.proto.input
            if isinstance(self.proto, onnx.FunctionProto)
            else [
                i.name
                for i in (
                    self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
                ).input
            ]
        )

    @property
    def output_names(self) -> List[str]:
        "Returns the list of output names."
        return (
            self.proto.output
            if isinstance(self.proto, onnx.FunctionProto)
            else [
                i.name
                for i in (
                    self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
                ).output
            ]
        )

    @classmethod
    def build_node_edges(cls, nodes: Sequence[onnx.NodeProto]) -> Set[Tuple[int, int]]:
        """
        Builds the list of edges between nodes.

        :param nodes: list of nodes
        :return: set of (producer index, consumer index),
            empty names (missing optional outputs) never create an edge
        """
        produced: Dict[str, int] = {}
        for i, node in enumerate(nodes):
            for o in node.output:
                if o:
                    produced[o] = i
        edges: Set[Tuple[int, int]] = set()
        for i, node in enumerate(nodes):
            for name in node.input:
                if name in produced:
                    edges.add((produced[name], i))
        return edges

    # Result of drawing the second symbol over the first one.
    ADD_RULES = {
        ("┴", "┘"): "┴",
        ("┴", "└"): "┴",
        ("┬", "┐"): "┬",
        ("┬", "┌"): "┬",
        ("-", "└"): "┴",
        ("-", "|"): "┼",
        ("-", "┐"): "┬",
        ("┐", "-"): "┬",
        ("┘", "-"): "┴",
        ("┴", "-"): "┴",
        ("-", "┘"): "┴",
        ("┌", "-"): "┬",
        ("┬", "-"): "┬",
        ("-", "┌"): "┬",
        ("|", "-"): "┼",
        ("└", "-"): "┴",
        ("|", "└"): "├",
        ("|", "┘"): "┤",
        ("┐", "|"): "┤",
        ("┬", "|"): "┼",
        ("|", "┐"): "┤",
        ("|", "┌"): "├",
        ("├", "-"): "┼",
        ("└", "|"): "├",
        ("┤", "┐"): "┤",
        ("┤", "|"): "┤",
        ("├", "|"): "├",
        ("┴", "┌"): "┼",
        ("┐", "┌"): "┬",
        ("┌", "┐"): "┬",
        # a vertical line through ┌ has no left arm, same as ("|", "┌")
        ("┌", "|"): "├",
        ("┴", "┐"): "┼",
        ("┐", "└"): "┼",
        ("┬", "┘"): "┼",
        ("├", "└"): "├",
        ("┤", "┌"): "┼",
        ("┘", "|"): "┤",
        ("┴", "|"): "┼",
        ("┤", "-"): "┼",
        ("┘", "└"): "┴",
    }

    # Directions (Up, Down, Left, Right) every box-drawing symbol connects,
    # used to combine two symbols missing from ADD_RULES.
    _GLYPH_DIRECTIONS = {
        "-": frozenset("LR"),
        "|": frozenset("UD"),
        "└": frozenset("UR"),
        "┘": frozenset("UL"),
        "┐": frozenset("LD"),
        "┌": frozenset("RD"),
        "┬": frozenset("LRD"),
        "┴": frozenset("LRU"),
        "├": frozenset("UDR"),
        "┤": frozenset("UDL"),
        "┼": frozenset("UDLR"),
    }
    _DIRECTION_GLYPHS = {v: k for k, v in _GLYPH_DIRECTIONS.items()}

    @classmethod
    def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str):
        """
        Prints inplace a text in a grid. The text is centered.
        Characters falling outside the row are dropped, the row keeps its length.

        :param grid: grid
        :param position: position
        :param text: text to print
        """
        row, col = position
        line = grid[row]
        begin = col - len(text) // 2
        for offset, char in enumerate(text):
            x = begin + offset
            if 0 <= x < len(line):
                line[x] = char

    @classmethod
    def text_edge(
        cls,
        grid: List[List[str]],
        p1: Tuple[int, int],
        p2: Tuple[int, int],
        mode: str = "square",
    ):
        """
        Prints inplace an edge in a grid.

        The edge leaves ``p1`` downwards. It runs horizontally on row
        ``p1[0] + 2`` up to the middle column, then vertically, then
        horizontally on row ``p2[0] - 2`` and down to ``p2``. When both
        turning rows are the same, a single horizontal run is drawn.
        When both points share a column, the edge is a straight line.
        Symbols already in the grid are merged with the new ones.

        :param grid: grid
        :param p1: first position
        :param p2: second position
        :param mode: ``'square'`` is the only supported value
        :raises NotImplementedError: if a symbol cannot be merged
            with the one already in the grid
        """
        assert mode == "square", f"mode={mode!r} not supported"
        assert p1[0] < p2[0], f"Unexpected edge p1={p1}, p2={p2}"
        assert p1[0] + 2 <= p2[0] - 2, f"Unexpected edge p1={p1}, p2={p2}"
        # removes this when the algorithm is ready
        assert 0 <= p1[0] < len(grid) - 3, f"p1={p1}, grid:{len(grid)},{len(grid[0])}"
        assert 2 <= p2[0] < len(grid) - 1, f"p2={p2}, grid:{len(grid)},{len(grid[0])}"
        assert (
            0 <= p1[1] < min(len(g) for g in grid)
        ), f"p1={p1}, sizes={[len(g) for g in grid]}"
        assert (
            0 <= p2[1] < min(len(g) for g in grid)
        ), f"p2={p2}, sizes={[len(g) for g in grid]}"

        def add(s1, s2):
            # merges symbol s2 drawn over symbol s1
            assert s2 != " ", f"s1={s1!r}, s2={s2!r}"
            if s1 == " " or s1 == s2:
                return s2
            if s1 == "┼" or s2 == "┼":
                return "┼"
            if (s1, s2) in cls.ADD_RULES:
                return cls.ADD_RULES[s1, s2]
            d1 = cls._GLYPH_DIRECTIONS.get(s1)
            d2 = cls._GLYPH_DIRECTIONS.get(s2)
            if d1 is None or d2 is None:
                raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',")
            return cls._DIRECTION_GLYPHS[d1 | d2]

        def place(x, y, symbol):
            grid[x][y] = add(grid[x][y], symbol)

        def hline(x, y1, y2):
            # horizontal run on row x strictly between columns y1 and y2
            for y in range(min(y1, y2) + 1, max(y1, y2)):
                place(x, y, "-")

        (r1, c1), (r2, c2) = p1, p2
        top, bottom = r1 + 2, r2 - 2
        place(r1 + 1, c1, "|")
        if c1 == c2:
            # straight vertical line
            for x in range(top, bottom + 1):
                place(x, c1, "|")
        elif top == bottom:
            place(top, c1, "└" if c1 < c2 else "┘")
            hline(top, c1, c2)
            place(top, c2, "┐" if c1 < c2 else "┌")
        else:
            middle = (c1 + c2) // 2
            if middle == c1:
                # no horizontal run on the top row
                place(top, c1, "|")
            else:
                place(top, c1, "└" if c1 < c2 else "┘")
                hline(top, c1, middle)
                place(top, middle, "┐" if c1 < c2 else "┌")
            for x in range(top + 1, bottom):
                place(x, middle, "|")
            if middle == c2:
                # no horizontal run on the bottom row
                place(bottom, c2, "|")
            else:
                place(bottom, middle, "└" if c1 < c2 else "┘")
                hline(bottom, middle, c2)
                place(bottom, c2, "┐" if c1 < c2 else "┌")
        place(r2 - 1, c2, "|")

    def text_rendering(self, prefix: str = "") -> str:
        """
        Renders a model in text.

        Every graph input becomes a pseudo node consuming ``BEGIN`` and every
        graph output a pseudo node producing ``END``. Nodes are placed with
        :meth:`computation_order`, :meth:`graph_positions` and
        :meth:`text_positions`, then edges are drawn and node types printed
        over them.

        :param prefix: string added at the beginning of every line
        :return: the rendering, an empty string if there is nothing to render

        .. runpython::
            :showcode:

            import textwrap
            import onnx
            import onnx.helper as oh
            from onnx_diagnostic.helpers.graph_helper import GraphRendering

            TFLOAT = onnx.TensorProto.FLOAT

            proto = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "Y"], ["xy"]),
                        oh.make_node("Neg", ["Y"], ["ny"]),
                        oh.make_node("Mul", ["xy", "ny"], ["a"]),
                        oh.make_node("Mul", ["a", "Y"], ["Z"]),
                    ],
                    "-nd-",
                    [
                        oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
                        oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                    ],
                    [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            graph = GraphRendering(proto)
            text = textwrap.dedent(graph.text_rendering()).strip("\\n")
            print(text)
        """
        nodes = [
            *[oh.make_node(i, ["BEGIN"], [i]) for i in self.input_names],
            *self.nodes,
            *[oh.make_node(i, [i], ["END"]) for i in self.output_names],
        ]
        if not nodes:
            return ""
        exist = set(self.start_names) - set(self.input_names)
        exist |= {"BEGIN"}
        existing = sorted(exist)
        order = self.computation_order(nodes, existing)
        positions = self.graph_positions(nodes, order, existing)
        text_pos = self.text_positions(nodes, positions)
        edges = self.build_node_edges(nodes)

        max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes)
        assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}"
        max_row = max(row for row, _ in text_pos) + 2
        grid = [[" " for i in range(max_len + 1)] for _ in range(max_row + 1)]

        # sorted for a deterministic drawing order
        for n1, n2 in sorted(edges):
            self.text_edge(grid, text_pos[n1], text_pos[n2])
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"

        # node types are printed last, over the edges
        for node, pos in zip(nodes, text_pos):
            self.text_grid(grid, pos, node.op_type)
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"

        return "\n".join(
            f"{prefix}{line.rstrip()}" for line in ["".join(line) for line in grid]
        )