import pprint
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
import onnx
import onnx.helper as oh
class GraphRendering:
    """
    Helper to render a graph as text.

    :param proto: model, graph or function to render
    """

    def __init__(self, proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto]):
        self.proto = proto

    def __repr__(self) -> str:
        "usual"
        return f"{self.__class__.__name__}(<{self.proto.__class__.__name__}>)"

    @classmethod
    def computation_order(
        cls,
        nodes: Sequence[onnx.NodeProto],
        existing: Optional[List[str]] = None,
        start: int = 1,
    ) -> List[int]:
        """
        Returns the soonest step at which every node can be computed.
        A node numbered ``k`` may assume every node with a lower number
        has already been computed; a node with a higher number must wait
        for the previous ones.

        :param nodes: list of nodes, in a valid topological order
        :param existing: names available before any computation starts,
            they receive the number ``start - 1``
        :param start: lowest number a node can receive
        :return: computation order, one number per node
        """
        op_types = {n.op_type for n in nodes}
        assert not ({"If", "Scan", "Loop", "SequenceMap"} & op_types), (
            f"This algorithme is not yet implemented if the sequence contains "
            f"a control flow, types={sorted(op_types)}"
        )
        number = {e: start - 1 for e in (existing or [])}  # noqa: C420
        results = [start for _ in nodes]
        for i_node, node in enumerate(nodes):
            # ONNX uses an empty name for an omitted optional input.
            inputs = [i for i in node.input if i]
            assert all(i in number for i in inputs), (
                f"Missing input in node {i_node} type={node.op_type}: "
                f"{[i for i in inputs if i not in number]}"
            )
            if inputs:
                mx = max(number[i] for i in inputs) + 1
                results[i_node] = mx
            else:
                # A constant is placed on the first row (``start``) but its
                # outputs are considered available only after everything
                # computed so far. It is never lower than ``start``, so that
                # a consumer always lands on a row after the constant itself.
                mx = max(start, max(number.values(), default=start - 1))
            for o in node.output:
                if o:
                    number[o] = mx
        return results

    @classmethod
    def graph_positions(
        cls,
        nodes: Sequence[onnx.NodeProto],
        order: List[int],
        existing: Optional[List[str]] = None,
    ) -> List[Tuple[int, int]]:
        """
        Returns positions on a plan for every node in a graph.
        The function tries to reduce the number of crossing lines: nodes
        are processed row by row, and the nodes of a row are sorted by the
        largest column among the producers of their inputs.
        It could be improved with more iterations.

        :param nodes: list of nodes
        :param order: computation order returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.computation_order`
        :param existing: existing names (unused, kept for backward compatibility)
        :return: list of tuple( row, column)
        """
        if not order:
            return []
        min_row = min(order)
        n_rows = max(order) + 1
        # group node indices by row once instead of rescanning order per row
        by_row: Dict[int, List[int]] = {}
        for i, o in enumerate(order):
            by_row.setdefault(o, []).append(i)
        names: Dict[str, int] = {}
        positions = [(min_row, i) for i in range(len(order))]
        for row in range(min_row, n_rows):
            indices = by_row.get(row, [])
            assert indices, f"indices cannot be empty for row={row}, order={order}"
            ns = [nodes[i] for i in indices]
            mx = [max((names.get(i, 0) for i in n.input if i), default=0) for n in ns]
            mix = sorted((m, i) for i, m in enumerate(mx))
            for c, (_m, i) in enumerate(mix):
                positions[indices[i]] = (row, c)
                for o in nodes[indices[i]].output:
                    if o:
                        names[o] = c
        return positions

    @classmethod
    def text_positions(
        cls, nodes: Sequence[onnx.NodeProto], positions: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """
        Returns positions for the nodes assuming they are rendered into text.

        :param nodes: list of nodes
        :param positions: positions returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.graph_positions`
        :return: text positions, (row, column of the center of the node)
        """
        if not positions:
            return []
        new_positions = [(row * 4, col * 2 + row) for row, col in positions]
        column_size = {col: 3 for _, col in new_positions}
        for i, (_row, col) in enumerate(new_positions):
            size = len(nodes[i].op_type) + 5
            column_size[col] = max(column_size[col], size)
            assert column_size[col] < 200, (
                f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, "
                f"nodes[i].op_type={nodes[i].op_type}"
            )
        # Every column takes a slot as wide as its widest node; the node is
        # centered in its slot. A running total replaces a quadratic max().
        results: Dict[int, int] = {}
        total = 0
        for col, size in sorted(column_size.items()):
            results[col] = total + size // 2
            total += size
        return [(row, results[col]) for row, col in new_positions]

    @property
    def nodes(self) -> List[onnx.NodeProto]:
        "Returns the list of nodes"
        return (
            self.proto.graph.node
            if isinstance(self.proto, onnx.ModelProto)
            else self.proto.node
        )

    @property
    def start_names(self) -> List[str]:
        "Returns the list of known names, inputs and initializers."
        graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
        input_names = (
            list(graph.input)
            if isinstance(graph, onnx.FunctionProto)
            else [i.name for i in graph.input]
        )
        init_names = (
            []
            if isinstance(graph, onnx.FunctionProto)
            else [
                *[i.name for i in graph.initializer],
                *[i.name for i in graph.sparse_initializer],
            ]
        )
        return [*input_names, *init_names]

    @property
    def input_names(self) -> List[str]:
        "Returns the list of input names."
        if isinstance(self.proto, onnx.FunctionProto):
            # function inputs are already names
            return list(self.proto.input)
        graph = self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
        return [i.name for i in graph.input]

    @property
    def output_names(self) -> List[str]:
        "Returns the list of output names."
        if isinstance(self.proto, onnx.FunctionProto):
            # function outputs are already names
            return list(self.proto.output)
        graph = self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
        return [i.name for i in graph.output]

    @classmethod
    def build_node_edges(cls, nodes: Sequence[onnx.NodeProto]) -> Set[Tuple[int, int]]:
        """
        Builds the set of edges between nodes, ``(producer index, consumer index)``.
        Empty names (omitted optional inputs or outputs in ONNX) are ignored.
        """
        produced = {}
        for i, node in enumerate(nodes):
            for o in node.output:
                if o:
                    produced[o] = i
        edges = set()
        for i, node in enumerate(nodes):
            for name in node.input:
                if name and name in produced:
                    edges.add((produced[name], i))
        return edges

    # Result of drawing the second symbol over the first one.
    ADD_RULES = {
        ("┴", "┘"): "┴",
        ("┴", "└"): "┴",
        ("┬", "┐"): "┬",
        ("┬", "┌"): "┬",
        ("-", "└"): "┴",
        ("-", "|"): "┼",
        ("-", "┐"): "┬",
        ("┐", "-"): "┬",
        ("┘", "-"): "┴",
        ("┴", "-"): "┴",
        ("-", "┘"): "┴",
        ("┌", "-"): "┬",
        ("┬", "-"): "┬",
        ("-", "┌"): "┬",
        ("|", "-"): "┼",
        ("└", "-"): "┴",
        ("|", "└"): "├",
        ("|", "┘"): "┤",
        ("┐", "|"): "┤",
        ("┬", "|"): "┼",
        ("|", "┐"): "┤",
        ("|", "┌"): "├",
        ("├", "-"): "┼",
        ("└", "|"): "├",
        ("┤", "┐"): "┤",
        ("┤", "|"): "┤",
        ("├", "|"): "├",
        ("┴", "┌"): "┼",
        ("┐", "┌"): "┬",
        ("┌", "┐"): "┬",
        # "┌" (down, right) + "|" (up, down) has no left arm, same as ("|", "┌")
        ("┌", "|"): "├",
        ("┴", "┐"): "┼",
        ("┐", "└"): "┼",
        ("┬", "┘"): "┼",
        ("├", "└"): "├",
        ("┤", "┌"): "┼",
        ("┘", "|"): "┤",
        ("┴", "|"): "┼",
        ("┤", "-"): "┼",
        ("┘", "└"): "┴",
    }

    # Directions (Up, Down, Left, Right) connected by every drawing symbol.
    # Two symbols not listed in ADD_RULES are merged by taking the union
    # of their directions; every union of two entries maps back to a symbol.
    _DIRECTIONS = {
        "|": frozenset("UD"),
        "-": frozenset("LR"),
        "└": frozenset("UR"),
        "┘": frozenset("UL"),
        "┐": frozenset("DL"),
        "┌": frozenset("DR"),
        "┬": frozenset("LRD"),
        "┴": frozenset("LRU"),
        "├": frozenset("UDR"),
        "┤": frozenset("UDL"),
        "┼": frozenset("UDLR"),
    }
    _SYMBOLS = {v: k for k, v in _DIRECTIONS.items()}

    @classmethod
    def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str):
        """
        Writes a text in place in a grid, centered on a position.

        :param grid: grid, a list of rows, every row is a list of characters
        :param position: position (row, column) of the center of the text
        :param text: text to print
        """
        row, col = position
        begin = col - len(text) // 2
        grid[row][begin : begin + len(text)] = list(text)

    @classmethod
    def text_edge(
        cls,
        grid: List[List[str]],
        p1: Tuple[int, int],
        p2: Tuple[int, int],
        mode: str = "square",
    ):
        """
        Draws in place an edge going down from ``p1`` to ``p2`` in a grid.

        The edge leaves ``p1`` vertically and turns on row ``p1[0] + 2``.
        It goes down along a column halfway between both ends, turns again
        on row ``p2[0] - 2`` and reaches ``p2`` vertically. If both ends share
        the same column, the edge is a straight vertical line. Symbols already
        in the grid are merged with the new ones.

        :param grid: grid, a list of rows, every row is a list of characters
        :param p1: first position (row, column), the producer
        :param p2: second position (row, column), the consumer
        :param mode: ``'square'`` is the only supported value
        """
        assert mode == "square", f"mode={mode!r} not supported"
        assert p1[0] < p2[0], f"Unexpected edge p1={p1}, p2={p2}"
        assert p1[0] + 2 <= p2[0] - 2, f"Unexpected edge p1={p1}, p2={p2}"
        # removes this when the algorithm is ready
        assert 0 <= p1[0] < len(grid) - 3, f"p1={p1}, grid:{len(grid)},{len(grid[0])}"
        assert 2 <= p2[0] < len(grid) - 1, f"p2={p2}, grid:{len(grid)},{len(grid[0])}"
        assert (
            0 <= p1[1] < min(len(g) for g in grid)
        ), f"p1={p1}, sizes={[len(g) for g in grid]}"
        assert (
            0 <= p2[1] < min(len(g) for g in grid)
        ), f"p2={p2}, sizes={[len(g) for g in grid]}"

        def add(s1, s2):
            assert s2 != " ", f"s1={s1!r}, s2={s2!r}"
            if s1 == " " or s1 == s2:
                return s2
            if s1 == "┼" or s2 == "┼":
                return "┼"
            if (s1, s2) in cls.ADD_RULES:
                return cls.ADD_RULES[s1, s2]
            d1 = cls._DIRECTIONS.get(s1)
            d2 = cls._DIRECTIONS.get(s2)
            if d1 is not None and d2 is not None:
                return cls._SYMBOLS[d1 | d2]
            raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',")

        def place(grid, x, y, symbol):
            grid[x][y] = add(grid[x][y], symbol)

        if p1[1] == p2[1]:
            # both ends share the same column: a straight vertical line
            for i in range(p1[0] + 1, p2[0]):
                place(grid, i, p1[1], "|")
            return

        to_right = p1[1] < p2[1]
        top, bottom = p1[0] + 2, p2[0] - 2
        place(grid, p1[0] + 1, p1[1], "|")
        place(grid, top, p1[1], "└" if to_right else "┘")
        if top == bottom:
            # a single horizontal segment joins both ends
            a, b = (p1[1] + 1, p2[1] - 1) if to_right else (p2[1] + 1, p1[1] - 1)
            for i in range(a, b + 1):
                place(grid, top, i, "-")
        else:
            middle = (p1[1] + p2[1]) // 2
            # upper horizontal segment, between p1 and the middle column
            a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1)
            for i in range(a, b + 1):
                place(grid, top, i, "-")
            place(grid, top, middle, "┐" if to_right else "┌")
            # vertical segment along the middle column
            for i in range(top + 1, bottom):
                place(grid, i, middle, "|")
            place(grid, bottom, middle, "└" if to_right else "┘")
            # lower horizontal segment, between the middle column and p2
            a, b = (middle + 1, p2[1] - 1) if middle < p2[1] else (p2[1] + 1, middle - 1)
            for i in range(a, b + 1):
                place(grid, bottom, i, "-")
        place(grid, bottom, p2[1], "┐" if to_right else "┌")
        place(grid, p2[0] - 1, p2[1], "|")

    def text_rendering(self, prefix="") -> str:
        """
        Renders a model in text.

        :param prefix: string added at the beginning of every line
        :return: the rendering, an empty string if there is nothing to render

        .. runpython::
            :showcode:

            import textwrap
            import onnx
            import onnx.helper as oh
            from onnx_diagnostic.helpers.graph_helper import GraphRendering
            TFLOAT = onnx.TensorProto.FLOAT
            proto = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "Y"], ["xy"]),
                        oh.make_node("Neg", ["Y"], ["ny"]),
                        oh.make_node("Mul", ["xy", "ny"], ["a"]),
                        oh.make_node("Mul", ["a", "Y"], ["Z"]),
                    ],
                    "-nd-",
                    [
                        oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
                        oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                    ],
                    [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            graph = GraphRendering(proto)
            text = textwrap.dedent(graph.text_rendering()).strip("\\n")
            print(text)
        """
        # Inputs and outputs are rendered as pseudo-nodes named after them,
        # connected to the virtual names BEGIN and END.
        nodes = [
            *[oh.make_node(i, ["BEGIN"], [i]) for i in self.input_names],
            *self.nodes,
            *[oh.make_node(i, [i], ["END"]) for i in self.output_names],
        ]
        if not nodes:
            return ""
        # inputs are produced by the pseudo-nodes, they are not existing names
        exist = set(self.start_names) - set(self.input_names)
        exist |= {"BEGIN"}
        existing = sorted(exist)
        order = self.computation_order(nodes, existing)
        positions = self.graph_positions(nodes, order, existing)
        text_pos = self.text_positions(nodes, positions)
        edges = self.build_node_edges(nodes)
        max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes)
        assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}"
        max_row = max(row for row, _ in text_pos) + 2
        grid = [[" " for i in range(max_len + 1)] for _ in range(max_row + 1)]
        # edges first, node names are written over them afterwards
        for n1, n2 in edges:
            self.text_edge(grid, text_pos[n1], text_pos[n2])
            assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
        for node, pos in zip(nodes, text_pos):
            self.text_grid(grid, pos, node.op_type)
            assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
        return "\n".join(
            f"{prefix}{line.rstrip()}" for line in ["".join(line) for line in grid]
        )