import os
from typing import Dict, List, Optional, Tuple
import onnx
import onnx.helper as oh
import torch
from ..reference.torch_ops import OpRunKernel, OpRunTensor
from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
from .ort_session import InferenceSessionForTorch
_SAVED: List[str] = []
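# When the environment variable DUMP_ONNX is set to a non-zero value, every created
# session dumps the model optimized by onnxruntime into a local file.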
_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
def _get_model_name(op_name: str, provider: str) -> Optional[str]:
    if _SAVE_OPTIMIZED_MODEL_:
        name = f"dump_doc_layer_norm_{provider}_{len(_SAVED)}.onnx"
        _SAVED.append(name)
        return name
    return None
class LayerNormalizationOrt(OpRunKernel):
    "LayerNormalization with onnxruntime"
    @classmethod
    def device_dependent(cls) -> bool:
        "Needs device."
        return True 
    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
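        """Reads the axis, epsilon and stash_type attributes from the node and records the target device."""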
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
        self.device = device
        self.stash_type = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )
        self.compute_std = len(node.output) > 1
        assert not self.compute_std, (
            f"This kernel implementation only works when a single output "
            f"is required but {node.output} were requested."
        )
        # cache of onnxruntime sessions keyed by (input dtype, rank)
        self._cache: Dict[Tuple[int, int], InferenceSessionForTorch] = {}
        self.is_cpu = torch.device("cpu") == self.device
    def _make_model(self, itype: int, rank: int, has_bias: bool) -> InferenceSessionForTorch:
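        """Builds a one-node LayerNormalization model for the given input type and rank
        and returns an InferenceSessionForTorch running it."""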
        shape = [*[f"d{i}" for i in range(rank - 1)], "last"]
        layer_model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "LayerNormalization",
                        ["X", "W", "B"] if has_bias else ["X", "W"],
                        ["Z"],
                        axis=self.axis,
                        epsilon=self.epsilon,
                    )
                ],
                "dummy",
                (
                    [
                        oh.make_tensor_value_info("X", itype, shape),
                        oh.make_tensor_value_info("W", itype, ["last"]),
                        oh.make_tensor_value_info("B", itype, ["last"]),
                    ]
                    if has_bias
                    else [
                        oh.make_tensor_value_info("X", itype, shape),
                        oh.make_tensor_value_info("W", itype, ["last"]),
                    ]
                ),
                [oh.make_tensor_value_info("Z", itype, shape)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        self._provider = provider
        return InferenceSessionForTorch(
            layer_model,
            optimized_model_filepath=_get_model_name("layer_norm", provider),
            providers=[provider],
        )
    def run(self, x, scale, bias=None):
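        """Runs LayerNormalization through onnxruntime on the kernel's device."""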
        itype = torch_dtype_to_onnx_dtype(x.dtype)
        rank = len(x.shape)
        key = itype, rank
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, rank, bias is not None)
        sess = self._cache[key]
        if self.verbose:
            print(f"[LayerNormalizationOrt] running on {self._provider!r}")
        feeds = dict(X=x.tensor, W=scale.tensor)
        if bias is not None:
            feeds["B"] = bias.tensor
        got = sess.run(None, feeds)[0]
        return OpRunTensor(got) 
 
class MatMulOrt(OpRunKernel):
    "MatMul with onnxruntime"
    @classmethod
    def device_dependent(cls) -> bool:
        "Needs device."
        return True 
    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
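        """Records the target device used to pick the execution provider."""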
        super().__init__(node, version, verbose=verbose)
        self.device = device
        # cache of onnxruntime sessions keyed by (input dtype, rank of A, rank of B)
        self._cache: Dict[Tuple[int, int, int], InferenceSessionForTorch] = {}
        self.is_cpu = torch.device("cpu") == self.device
    def _make_model(self, itype: int, ranka: int, rankb: int) -> InferenceSessionForTorch:
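        """Builds a one-node MatMul model for the given input type and ranks
        and returns an InferenceSessionForTorch running it."""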
        shapea = [f"a{i}" for i in range(ranka)]
        shapeb = [f"b{i}" for i in range(rankb)]
        shapec = [f"c{i}" for i in range(max(ranka, rankb))]
        model = oh.make_model(
            oh.make_graph(
                [oh.make_node("MatMul", ["A", "B"], ["C"])],
                "dummy",
                [
                    oh.make_tensor_value_info("A", itype, shapea),
                    oh.make_tensor_value_info("B", itype, shapeb),
                ],
                [oh.make_tensor_value_info("C", itype, shapec)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        self._provider = provider
        return InferenceSessionForTorch(
            model,
            optimized_model_filepath=_get_model_name("matmul", provider),
            providers=[provider],
        )
    def run(self, a, b):
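        """Runs MatMul through onnxruntime on the kernel's device."""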
        itype = torch_dtype_to_onnx_dtype(a.dtype)
        ranka, rankb = len(a.shape), len(b.shape)
        key = itype, ranka, rankb
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, ranka, rankb)
        sess = self._cache[key]
        if self.verbose:
            print(f"[MatMulOrt] running on {self._provider!r}")
        feeds = dict(A=a.tensor, B=b.tensor)
        got = sess.run(None, feeds)[0]
        return OpRunTensor(got)
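
# A minimal usage sketch (an assumption, not part of this module: it presumes
# OpRunTensor wraps a plain torch.Tensor and that the CPU execution provider
# is available):
#
#     node = oh.make_node("MatMul", ["A", "B"], ["C"])
#     kernel = MatMulOrt(node, device=torch.device("cpu"), verbose=1)
#     a = OpRunTensor(torch.randn(2, 3))
#     b = OpRunTensor(torch.randn(3, 4))
#     c = kernel.run(a, b)  # OpRunTensor holding the 2x4 result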