import functools
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import onnx
import torch
from ..helpers.torch_helper import to_tensor
from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
from .report_results_comparison import ReportResultComparison
from . import torch_ops
[docs]
@functools.lru_cache
def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRunKernel]]:
    """
    Retrieves all the available kernels class :class:`TorchOnnxEvaluator`
    can use. The full list is the following.
    .. runpython::
        :showcode:
        from onnx_diagnostic.reference.torch_evaluator import get_kernels
        for k, v in sorted(get_kernels().items()):
            domain, name, version = k
            f = f"{name}({version})" if domain == "" else f"{name}[{domain}]({version})"
            add = " " * max(25 - len(f), 0)
            dd = " -- device dependent" if v.device_dependent() else ""
            print(f"{f}{add} -- {v.__name__}{dd}")
    """
    res = {}
    for _k, v in torch_ops.__dict__.items():
        if isinstance(v, type) and issubclass(v, torch_ops.OpRunKernel) and "_" in v.__name__:
            name, version = v.__name__.split("_")
            domain = getattr(v, "domain", "")
            res[domain, name, int(version)] = v
    return res 
[docs]
class TorchOnnxEvaluator:
    """
    Torch evaluator for onnx models.
    The model does not stores the original proto it evaluates to avoid
    :param proto: a proto
    :param providers: where to run the model
    :param opsets: needed if proto is a graph
    :param functions: known local functions
    :param verbose: verbosity level
    :param custom_kernels: dictionary of kernels the user can defined to overwrite
        a specific implementation: ``("", "LayerNormalization"): CustomKernel``
    The class holds the following attributes:
    * `providers`: providers
    * `default_device`: default torch device
    * `constants`: all initializers or constants
    * `kernels`: kernels
    * `runtime_info`: produced by :func:`first_used_last_used
      <onnx_diagnostic.torch_onnx.runtime_info.first_used_last_used>`
    * `last_used`: contains the list of intermediate results,
       to remove after every node execution,
       this avoid the memory to grow too much
    * `functions`: local functions
    The class is not multithreaded. `runtime_info` gets updated
    by the the class. The list of available kernels is returned by function
    :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
    Example:
    .. runpython::
        :showcode:
        import onnx
        import onnx.helper as oh
        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.reference import TorchOnnxEvaluator
        TFLOAT = onnx.TensorProto.FLOAT
        proto = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("Sigmoid", ["Y"], ["sy"]),
                    oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
                    oh.make_node("Mul", ["X", "ysy"], ["final"]),
                ],
                "-nd-",
                [
                    oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
                    oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                ],
                [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )
        sess = TorchOnnxEvaluator(proto)
        feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
        result = sess.run(None, feeds)
        print(string_type(result, with_shape=True, with_min_max=True))
    With ``verbose=1``, the class prints out every kernel run and
    and every result deleted along the run.
    It shows when a result is not needed anymore. In that case,
    it is deleted to free the memory it takes.
    .. runpython::
        :showcode:
        import onnx
        import onnx.helper as oh
        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.reference import TorchOnnxEvaluator
        TFLOAT = onnx.TensorProto.FLOAT
        proto = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("Sigmoid", ["Y"], ["sy"]),
                    oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
                    oh.make_node("Mul", ["X", "ysy"], ["final"]),
                ],
                "-nd-",
                [
                    oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
                    oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                ],
                [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )
        sess = TorchOnnxEvaluator(proto, verbose=1)
        feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
        result = sess.run(None, feeds)
        print(string_type(result, with_shape=True, with_min_max=True))
    The runtime can also execute the kernel the onnx model on CUDA.
    It follows the same logic as :class:`onnxruntime.InferenceSession`:
    ``providers=["CUDAExecutionProvider"]``.
    It is better in that case to move the input on CUDA. The class
    tries to move every weight on CUDA but tries to keep any tensor
    identified as a shape in CPU. Some bugs may remain as torch
    raises an exception when devices are expected to be the same.
    The runtime was validated with model :epkg:`arnir0/Tiny-LLM`.
    Next example shows how to replace a kernel with a different
    one based on :epkg:`onnxruntime`.
    .. runpython::
        :showcode:
        import numpy as np
        import onnx
        import onnx.helper as oh
        import onnxruntime
        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
        from onnx_diagnostic.reference import TorchOnnxEvaluator
        from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
        TFLOAT16 = onnx.TensorProto.FLOAT16
        class LayerNormalizationOrt(OpRunKernel):
            "LayerNormalization based on onnxruntime"
            def __init__(self, node: onnx.NodeProto, version=None, verbose=0):
                super().__init__(node, version, verbose=verbose)
                self.axis = self.get_attribute_int(node, "axis", -1)
                self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
                self.stash_type = onnx_dtype_to_torch_dtype(
                    self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
                )
                self.compute_std = len(node.output) > 1
                assert not self.compute_std, "The keren only computes the first output."
                layer_model = oh.make_model(
                    oh.make_graph(
                        [
                            oh.make_node(
                                "LayerNormalization",
                                ["X", "W", "B"],
                                ["Z"],
                                axis=-1,
                                epsilon=9.999999974752427e-7,
                            )
                        ],
                        "dummy",
                        [
                            oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
                            oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
                            oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
                        ],
                        [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
                    ),
                    ir_version=9,
                    opset_imports=[oh.make_opsetid("", 17)],
                )
                self.ort_sess = onnxruntime.InferenceSession(
                    layer_model.SerializeToString(), providers=["CUDAExecutionProvider"]
                )
            def run(self, x, scale, bias=None):
                print(f"-- running {self.__class__.__name__}")
                feeds = dict(X=x, W=scale)
                if bias is not None:
                    feeds["B"] = bias
                feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
                got = self.ort_sess.run(None, feeds)[0]
                return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
        # This kernel is tested on this model.
        model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "LayerNormalization",
                        ["X", "W", "B"],
                        ["ln"],
                        axis=-1,
                        epsilon=9.999999974752427e-7,
                    ),
                    oh.make_node(
                        "Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
                    ),
                ],
                "dummy",
                [
                    oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
                    oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
                    oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
                ],
                [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 17)],
        )
        torch_sess = TorchOnnxEvaluator(
            model,
            custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
            verbose=1,
        )
        feeds = dict(
            zip(
                torch_sess.input_names,
                [
                    torch.rand(3, 4, 5, dtype=torch.float16),
                    torch.abs(torch.rand(5, dtype=torch.float16)),
                    torch.rand(5, dtype=torch.float16),
                ],
            )
        )
        res = torch_sess.run(None, feeds)
        print(string_type(res, with_shape=True, with_min_max=True))
    """
[docs]
    class IO:
        "IO"
        def __init__(self, name: str, type: int, shape: Tuple[Union[str, int], ...]):
            self.name = name
            self.type = type
            self.shape = shape 
    @classmethod
    def _on_cuda(cls, providers) -> int:
        if not providers:
            return -1
        for p in providers:
            if p == "CUDAExecutionProvider":
                return 0
            if isinstance(p, tuple) and p[0] == "CUDAExecutionProvider":
                return p[1]["device_id"]
        return -1
    def __init__(
        self,
        proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
        providers: Tuple[str, ...] = ("CPUExecutionProvider",),
        opsets: Optional[Dict[str, int]] = None,
        local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
        verbose: int = 0,
        custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRunKernel]]] = None,
    ):
        self.providers = providers
        self.constants: Dict[str, torch.Tensor] = {}
        self.kernels: List[Optional[torch_ops.OpRunKernel]] = []
        self.functions = local_functions.copy() if local_functions else {}
        self.CPU = torch.tensor([0]).to("cpu").device
        self.verbose = verbose
        self.custom_kernels = custom_kernels or {}
        dev = self._on_cuda(providers)
        if dev < 0:
            self.default_device = self.CPU
            self.CUDA = None
        else:
            self.CUDA = torch.tensor([0]).to(f"cuda:{dev}").device
            self.default_device = self.CUDA
        if isinstance(proto, str):
            proto = onnx.load(proto)
        if isinstance(proto, onnx.ModelProto):
            assert opsets is None, "proto is a model, opsets must be None in that case"
            assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
            self.opsets = {d.domain: d.version for d in proto.opset_import}
            for f in proto.functions:
                self.functions[f.domain, f.name] = self.__class__(
                    f,
                    providers=providers,
                    local_functions=self.functions,
                    verbose=self.verbose,
                )
            self._build_initializers(proto.graph.initializer)
            self._build_initializers(proto.graph.node)
            self._build_kernels(proto.graph.node)
            self.input_names = [i.name for i in proto.graph.input]
            self.output_names = [i.name for i in proto.graph.output]
            self._io_input_names = [
                self.IO(
                    name=i.name,
                    type=i.type.tensor_type.elem_type,
                    shape=tuple(
                        d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
                    ),
                )
                for i in proto.graph.input
            ]
            self._io_output_names = [
                self.IO(
                    name=i.name,
                    type=i.type.tensor_type.elem_type,
                    shape=tuple(
                        d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
                    ),
                )
                for i in proto.graph.output
            ]
        elif isinstance(proto, onnx.GraphProto):
            assert opsets, "opsets must be specified if proto is a graph"
            assert not proto.sparse_initializer, "sparse_initializer not support yet"
            self.opsets = opsets
            self._build_initializers(proto.initializer)
            self._build_initializers(proto.node)
            self._build_kernels(proto.node)
            self.input_names = [i.name for i in proto.input]
            self.output_names = [i.name for i in proto.output]
        elif isinstance(proto, onnx.FunctionProto):
            assert opsets is None, "proto is a model, opsets must be None in that case"
            self.opsets = {d.domain: d.version for d in proto.opset_import}
            self._build_initializers(proto.node)
            self._build_kernels(proto.node)
            self.input_names = list(proto.input)
            self.output_names = list(proto.output)
        else:
            raise TypeError(f"Unexpected type {type(proto)} for proto")
        self.runtime_info = first_used_last_used(proto, constant_as_initializer=True)
        self.last_used: List[List[str]] = [[] for _ in self.kernels]
        for name, info in self.runtime_info.items():
            assert isinstance(info.last_used, int) or info.is_input, (
                f"Missing field last_used in {info!r}, last_used={info.last_used!r}, "
                f"This may mean the node is unused and it should be removed."
            )
            if info.last_used is None:
                # Not used.
                self.last_used[0].append(name)
            elif not info.is_output and not info.is_initializer:
                self.last_used[info.last_used].append(name)
[docs]
    def get_outputs(self):
        "Same API than onnxruntime."
        assert hasattr(self, "_io_output_names"), "Missing attribute '_io_output_names'."
        return self._io_output_names 
    @property
    def on_cuda(self) -> bool:
        "Tells if the default device is CUDA."
        return self.default_device == self.CUDA
    def _build_initializers(self, inits: Sequence[Union[onnx.NodeProto, onnx.TensorProto]]):
        for init in inits:
            if isinstance(init, onnx.TensorProto):
                self.constants[init.name] = to_tensor(init).to(self.default_device)
            elif (
                isinstance(init, onnx.NodeProto)
                and init.op_type == "Constant"
                and init.domain == ""
            ):
                value = None
                for att in init.attribute:
                    if att.name == "value":
                        value = to_tensor(att.t).to(self.default_device)
                    elif att.name == "value_floats":
                        value = torch.tensor(list(att.floats), dtype=torch.float32).to(
                            self.default_device
                        )
                assert value is not None, f"No attribute value in node {init}"
                self.constants[init.output[0]] = value
    def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
        kernels = get_kernels()
        self.kernels.clear()
        for node in nodes:
            kernel_kwargs = dict(verbose=max(0, self.verbose - 1))
            opset = self.opsets[node.domain]
            key = node.domain, node.op_type, opset
            if key[:2] in self.custom_kernels:
                cls = self.custom_kernels[key[:2]]
                ags = [self.default_device] if cls.device_dependent() else []
                kws = dict(parent=self) if cls.has_subgraphs() else {}
                kws.update(kernel_kwargs)  # type: ignore[arg-type]
                kernel2 = cls(node, opset, *ags, **kws)  # type: ignore[arg-type]
                self.kernels.append(kernel2)
                continue
            if (node.domain, node.op_type) in self.functions:
                kernel = torch_ops.OpRunFunction(
                    self.functions[node.domain, node.op_type],
                    node,
                    self.opsets[node.domain],
                    **kernel_kwargs,
                )
                self.kernels.append(kernel)
                continue
            if node.op_type == "Constant" and node.domain == "":
                # Treated as a constant.
                self.kernels.append(None)
                continue
            while key not in kernels and opset > 0:
                opset -= 1
                key = node.domain, node.op_type, opset
            assert key in kernels, (
                f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}, "
                f"local functions={sorted(self.functions)}"
            )
            cls = kernels[key]
            ags = [self.default_device] if cls.device_dependent() else []
            kws = dict(parent=self) if cls.has_subgraphs() else {}
            kws.update(kernel_kwargs)  # type: ignore[arg-type]
            kernel2 = cls(node, opset, *ags, **kws)  # type: ignore[arg-type]
            self.kernels.append(kernel2)
[docs]
    def run(
        self,
        outputs: Optional[List[str]],
        feeds: Union[Dict[str, torch.Tensor], Dict[str, np.ndarray]],
        report_cmp: Optional[ReportResultComparison] = None,
    ) -> Union[List[Optional[torch.Tensor]], List[Optional[np.ndarray]]]:
        """
        Runs the ONNX model.
        :param outputs: outputs required
        :param feeds: inputs
        :param report_cmp: used as a reference,
            every intermediate results is compare to every existing one,
            if not empty, it is an instance of
            :class:`onnx_diagnostic.reference.ReportResultComparison`
        :return: output tensors.
        """
        use_numpy = any(isinstance(t, np.ndarray) for t in feeds.values())
        if use_numpy:
            feeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
        if outputs is None:
            outputs = self.output_names
        # sets constants
        for k, v in self.constants.items():
            r = self.runtime_info[k]
            if not r.has_value:
                r.set_value(
                    torch_ops.OpRunTensor(
                        v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                        is_constant=True,
                        may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                    )
                )
            if self.verbose:
                print(f"+C {r.name}: {r.string_type()}")
        # inputs
        for k, v in feeds.items():
            r = self.runtime_info[k]
            r.set_value(
                torch_ops.OpRunTensor(
                    v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                    is_constant=False,
                    may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                )
            )
            if self.verbose:
                print(f"+I {r.name}: {r.string_type()}")
        # node execution
        for it, kernel in enumerate(self.kernels):
            if kernel is not None:
                if self.verbose:
                    print(
                        f"{kernel.__class__.__name__}"
                        f"({', '.join(kernel.input)}) -> "
                        f"{', '.join(kernel.output)}"
                    )
                # kernel execution
                inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
                if kernel.has_subgraphs():
                    res = kernel.run(*inputs, context=self.runtime_info)  # type: ignore[call-arg]
                else:
                    res = kernel.run(*inputs)
                if isinstance(res, tuple):
                    # outputs
                    assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
                        f"Unexpected output type {[type(o) for o in res]} "
                        f"for kernel {type(kernel)}."
                    )
                    for name, t in zip(kernel.output, res):
                        self.runtime_info[name].set_value(t)
                    if self.verbose:
                        for name in kernel.output:
                            print(f"+R {name}: {self.runtime_info[name].string_type()}")
                else:
                    assert isinstance(
                        res, torch_ops.OpRunValue
                    ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
                    self.runtime_info[kernel.output[0]].set_value(res)
                    if self.verbose:
                        print(
                            f"+R {kernel.output[0]}: "
                            f"{self.runtime_info[kernel.output[0]].string_type()}"
                        )
                if report_cmp:
                    reported = report_cmp.report(
                        dict(
                            zip(
                                kernel.output,
                                (
                                    tuple((r.tensor if r else None) for r in res)  # type: ignore[attr-defined]
                                    if isinstance(res, tuple)
                                    else ((res.tensor if res else None),)  # type: ignore[attr-defined]
                                ),
                            )
                        )
                    )
                    if self.verbose > 1:
                        print(f"  -- report {len(reported)} comparisons")
            # free intermediate results
            for name in self.last_used[it]:
                self.runtime_info[name].clean_value()
                if self.verbose:
                    print(f"- clean {name}")
        assert all(
            self.runtime_info[o].value is not None for o in outputs
        ), "Not implemented yet when one output is None."
        fres = [self.runtime_info[o].value.tensor for o in outputs]  # type: ignore[union-attr]
        if self.verbose:
            print(f"++ outputs {', '.join(outputs)}")
        # clean previous execution
        for k in feeds:
            self.runtime_info[k].clean_value()
            if self.verbose:
                print(f"- clean {k}")
        for o in outputs:
            self.runtime_info[o].clean_value()
            if self.verbose:
                print(f"- clean {o}")
        if use_numpy:
            return [None if a is None else a.detach().cpu().numpy() for a in fres]
        return fres 
[docs]
    def run_with_values(
        self,
        *args: Optional[torch_ops.OpRunTensor],
        context: Optional[Dict[str, RuntimeValue]] = None,
    ) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
        """
        Runs the ONNX model. The signature is different.
        This method is called by every kernel hokding a subgraph.
        The local variables are stored in `context`.
        :param args: inputs
        :param context: local context for the execution of subgraphs
        :return: output OpRunTensor
        """
        assert all(
            isinstance(a, torch_ops.OpRunValue) for a in args
        ), f"Unexpected type in args: {[type(a) for a in args]}"
        outputs = self.output_names
        context = context or {}
        # sets constants
        for k, v in self.constants.items():
            r = self.runtime_info[k]
            if not r.has_value:
                r.set_value(
                    torch_ops.OpRunTensor(
                        v.to(self.CUDA) if r.is_shape is False and self.on_cuda else v,
                        is_constant=True,
                        may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                    )
                )
        # inputs
        for k, v in zip(self.input_names, args):
            r = self.runtime_info[k]
            r.set_value(
                torch_ops.OpRunTensor(None) if v is None else v.__class__(v.tensor_or_sequence)
            )
        # node execution
        for it, kernel in enumerate(self.kernels):
            if kernel is not None:
                # kernel execution
                inputs = [
                    (
                        (
                            self.runtime_info[i].value
                            if i in self.runtime_info
                            else context[i].value
                        )
                        if i
                        else None
                    )
                    for i in kernel.input
                ]
                res = kernel.run(*inputs)
                if isinstance(res, tuple):
                    # outputs
                    assert all(isinstance(o, torch_ops.OpRunTensor) for o in res), (
                        f"Unexpected output type {[type(o) for o in res]} "
                        f"for kernel {type(kernel)}."
                    )
                    for name, t in zip(kernel.output, res):
                        self.runtime_info[name].set_value(t)
                else:
                    assert isinstance(
                        res, torch_ops.OpRunValue
                    ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
                    self.runtime_info[kernel.output[0]].set_value(res)
            # free intermediate results
            for name in self.last_used[it]:
                self.runtime_info[name].clean_value()
        assert all(
            self.runtime_info[o].value is not None for o in outputs
        ), "Not implemented yet when one output is None."
        res2 = [self.runtime_info[o].value.copy() for o in outputs]  # type: ignore[assignment, union-attr]
        # clean previous execution
        for k in self.input_names:
            self.runtime_info[k].clean_value()
        for o in self.output_names:
            self.runtime_info[o].clean_value()
        return res2[0] if len(res2) == 1 else tuple(res2)  # type: ignore[index, return-value, arg-type]