import inspect
from typing import List, Optional
import numpy as np
import onnx.numpy_helper as onh
from onnx import NodeProto
from ..patterns_api import MatchResult, PatternOptimization
class TriMatrixPattern(PatternOptimization):
    """
    Replaces a sequence of nodes creating a triangular matrix
    with operator TriMatrix(...).

    The fused subgraph (only considered when the CUDA processor is enabled)
    is::

        r    = Range(0, N, 1)
        col  = Reshape(Add(r, 1), [N, 1])
        mask = Less(r, col)                # mask[i, j] is j <= i
        full = ConstantOfShape([N, N])     # fill value U
        out  = Where(mask, L, full)        # L is a constant scalar

    It is rewritten into a single node
    ``out = TriMatrix([N, N], [L, L, U])`` from the domain
    ``onnx_extended.ortops.optim.cuda``.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tries to match the subgraph described in the class docstring,
        starting from its ``Range`` node.

        :param g: graph being optimized
        :param node: candidate ``Range`` node
        :param matched: matches already found (unused here)
        :return: a :class:`MatchResult` listing the nodes
            ``[range, add, reshape, less, where, constant_of_shape]``,
            or the result of ``self.none(...)`` when the pattern does not apply
        """
        if not g.has_processor("CUDA"):
            return self.none()
        if node.op_type != "Range" or node.domain != "":
            return self.none()
        if (
            len(node.input) != 3
            or not g.is_constant_scalar(node.input[0])
            or not g.is_constant_scalar(node.input[1])
            or not g.is_constant_scalar(node.input[2])
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        start, limit, delta = [g.get_constant_scalar(i) for i in node.input]
        if start != 0 or delta != 1:
            return self.none(node, inspect.currentframe().f_lineno)

        # The Range output must feed exactly one Add and one Less
        # (both from the default ONNX domain).
        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 2 or any(n.domain != "" for n in next_nodes):
            return self.none(node, inspect.currentframe().f_lineno)
        types = {n.op_type for n in next_nodes}
        if types != {"Add", "Less"}:
            return self.none(node, inspect.currentframe().f_lineno)
        if next_nodes[0].op_type == "Add":
            add_node, less_node = next_nodes
        else:
            less_node, add_node = next_nodes

        # Add(r, 1)
        if (
            add_node.input[0] != node.output[0]
            or not g.is_constant_scalar(add_node.input[1])
            or g.get_constant_scalar(add_node.input[1]) != 1
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        # Reshape(r + 1, [N, 1]): its only consumer must be the Less node,
        # otherwise removing it would break the graph.
        resh_node = g.next_nodes(add_node.output[0])
        if (
            len(resh_node) != 1
            or resh_node[0].op_type != "Reshape"
            or resh_node[0].domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        reshape_node = resh_node[0]
        if not g.is_constant(reshape_node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape = g.get_computed_constant(reshape_node.input[1])
        if shape is None or shape.tolist() != [limit, 1]:
            return self.none(node, inspect.currentframe().f_lineno)
        if len(g.next_nodes(reshape_node.output[0])) != 1:
            return self.none(node, inspect.currentframe().f_lineno)

        # Less(r, col): broadcasting gives mask[i, j] = j < i + 1.
        if list(less_node.input) != [node.output[0], reshape_node.output[0]]:
            return self.none(node, inspect.currentframe().f_lineno)

        # Where(mask, L, full) with L a constant scalar.
        where_nodes = g.next_nodes(less_node.output[0])
        if (
            len(where_nodes) != 1
            or where_nodes[0].op_type != "Where"
            or where_nodes[0].domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        where_node = where_nodes[0]
        if where_node.input[0] != less_node.output[0] or not g.is_constant_scalar(
            where_node.input[1]
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        # full = ConstantOfShape([N, N]); node_before returns None when the
        # input is not produced by a node (initializer or graph input).
        cst_node = g.node_before(where_node.input[2])
        if (
            cst_node is None
            or cst_node.op_type != "ConstantOfShape"
            or cst_node.domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        # The ConstantOfShape node is removed as well, so nothing else may use it.
        if len(g.next_nodes(cst_node.output[0])) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(cst_node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape = g.get_computed_constant(cst_node.input[0])
        if shape is None or shape.tolist() != [limit, limit]:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(
            self,
            [node, add_node, reshape_node, less_node, where_node, cst_node],
            self.apply,
            insert_at=where_node,
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        range_node: NodeProto,
        add_node: NodeProto,
        reshape_node: NodeProto,
        less_node: NodeProto,
        where_node: NodeProto,
        cst_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the ``TriMatrix`` node replacing the matched nodes.

        The second input of ``TriMatrix`` is a 1-D initializer holding
        ``[lower, diagonal, upper]``. Lower and diagonal both take the scalar
        ``Where`` selects when the mask is true. Upper takes the fill value
        of ``ConstantOfShape``.

        :param g: graph builder receiving the new initializer
        :param range_node: matched ``Range`` node (removed)
        :param add_node: matched ``Add`` node (removed)
        :param reshape_node: matched ``Reshape`` node (removed)
        :param less_node: matched ``Less`` node (removed)
        :param where_node: matched ``Where`` node, its outputs are reused
        :param cst_node: matched ``ConstantOfShape`` node, its shape input
            is reused
        :return: the list with the single new node
        """
        values = [att.t for att in cst_node.attribute if att.name == "value"]
        if values:
            cst_upper = onh.to_array(values[0])
        else:
            # ONNX ConstantOfShape defaults to a float32 zero when the
            # 'value' attribute is omitted.
            cst_upper = np.zeros((1,), dtype=np.float32)
        dtype = cst_upper.dtype
        cst_lower = np.array([g.get_constant_scalar(where_node.input[1])], dtype=dtype)
        cst_diag = cst_lower
        csts_array = np.hstack([cst_lower, cst_diag, cst_upper]).astype(dtype)
        assert csts_array.shape == (3,), f"Wrong constant array: {csts_array}"
        # Tensor names are unique in the graph, node names may be empty:
        # derive the initializer name from the output to avoid collisions.
        cst_name = g.make_initializer(
            f"{self.__class__.__name__}--{where_node.output[0]}",
            csts_array,
            source="TriMatrixPattern.apply.cst",
        )
        new_node = g.make_node(
            "TriMatrix",
            [cst_node.input[0], cst_name],
            where_node.output,
            name=f"{self.__class__.__name__}--{where_node.name}",
            domain="onnx_extended.ortops.optim.cuda",
        )
        return [new_node]