from typing import Any, Callable, Dict, List, Optional
from .dispatcher import Dispatcher
from .oxs_opset import OxsOpset, Var
[docs]
class OxsDispatcher(Dispatcher):
    """
    If :class:`DynamoInterpreter
    <experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter>`
    cannot find any converting function for a specific function,
    it tries to find an existing one in :epkg:`onnxscript`.
    The converting function from onnxscript is run in trace only mode.
    The variable and functions op, Rank, IsScalar are replaced by
    `op = OwsOpset()`, `op.Rank`, `op.Scalar`.
    onnxscript may have multiple overloaded functions.
    Right now, it takes the first one.
    :param verbose: verbose
    """
    def __init__(self, verbose: int = 0):
        super().__init__({}, verbose=verbose)
        self._submodule = None
    @property
    def submodules(self) -> Dict[str, Callable]:
        """
        Returns the submodules implementing torch functions.
        """
        if self._submodule is not None:
            return self._submodule
        from onnxscript.function_libs.torch_lib.ops import (
            core,
            fft,
            linalg,
            nn,
            prims,
            special,
            vision,
        )
        subs = {
            "onnxscript.function_libs.torch_lib.ops.core": core,
            "onnxscript.function_libs.torch_lib.ops.fft": fft,
            "onnxscript.function_libs.torch_lib.ops.linalg": linalg,
            "onnxscript.function_libs.torch_lib.ops.nn": nn,
            "onnxscript.function_libs.torch_lib.ops.prims": prims,
            "onnxscript.function_libs.torch_lib.ops.special": special,
            "onnxscript.function_libs.torch_lib.ops.vision": vision,
        }
        self._submodule = subs
        return subs
[docs]
    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: "GraphBuilder",  # noqa: F821
    ) -> Optional[Callable]:
        """
        The function is called after the function converting an aten function
        into ONNX. *fct* is this function. It can be changed and just
        set when mapping was found.
        :param name: object or str
        :param fct: function found so far
        :param args: known arguments coming from the graph module
        :param kwargs: known named arguments coming from the graph module
        :param builder: GraphBuilder
        :return: callable
        """
        if fct is not None:
            # The conversion has been found.
            return fct
        from onnxscript.function_libs.torch_lib.registration import default_registry
        if hasattr(name, "__qualname__") and "::" in name.__qualname__:
            key = name.__qualname__
        else:
            key = str(name)
        if key.startswith("aten."):
            key = "aten::" + key[6:]
        if key not in default_registry:
            if self.verbose > 2:
                print(
                    f"[OxsDispatcher.fallback] unable to find any fallback "
                    f"for {name!r} or {key!r}"
                )
            return None
        regfct = default_registry[key]
        assert len(regfct.overloads) > 0, (
            f"Unable to find onnxscript submodule {fct.function.__module__!r}. "
            f"onnxscript has a function with no overloaded instances, "
            f"key={key!r}, name={name!r}{builder.get_debug_msg()}"
        )
        fct = regfct.overloads[0]
        assert fct.__module__ in self.submodules, (
            f"Unable to find onnxscript submodule {fct.__module__!r}. "
            f"The fallback to onnxscript is not implemented yet for function "
            f"key={key!r}, name={name!r}{builder.get_debug_msg()}"
        )
        if self.verbose > 3:
            print(
                f"[OxsDispatcher.fallback] found {len(regfct.overloads)} "
                f"overloads for {key!r} ({name!r}), taking the first one."
            )
        def wrapper(g, sts, outputs, *args, _fct=fct, _dispatcher=self, **kwargs):
            op = OxsOpset(g)
            vargs = [(Var(x) if isinstance(x, str) else x) for x in args]
            # rewrite op, Rank, IsScalar in every submodule
            old = self._update_oxs(op)
            # call the function
            res = _fct.__wrapped__(*vargs, **kwargs)
            # restore op, Rank, IsScalar
            self._restore_oxs(old)
            if isinstance(res, tuple):
                tres = tuple(r.name for r in res)
                cres = tres
            else:
                tres = res.name
                cres = (res.name,)
            if outputs is None:
                return tres
            # We need to rename.
            assert len(outputs) == len(cres), (
                f"Mismatched number of outputs, expecting {outputs!r} but got "
                f"{len(cres)} from {name!r} (key={key!r}){g.get_debug_msg()}"
            )
            for r, o in zip(cres, outputs):
                builder.op.Identity(r, outputs=[o], name=key.replace("::", "."))
            if len(outputs) == 1:
                return outputs[0]
            return tuple(outputs)
        return wrapper 
    def _update_oxs(self, op: OxsOpset):
        keep = {}
        for k, v in self.submodules.items():
            old = v.op, getattr(v, "IsScalar", None), getattr(v, "Rank", None)
            v.op = op
            v.Rank = op.Rank
            v.IsScalar = op.IsScalar
            keep[k] = old
        return keep
    def _restore_oxs(self, old):
        for k, v in self.submodules.items():
            v.op, v.IsScalar, v.Rank = old[k] 
[docs]
class OxsDebugDispatcher(OxsDispatcher):
    """
    Tries the fallback even if is not necessary to check
    it is working.
    :param verbose: verbosity
    :param raise_exc: fail or raise an exception
    The class can be used the following way.
    .. runpython::
        :showcode:
        :process:
        import torch
        from experimental_experiment.torch_models.llama_helper import get_llama_model
        from experimental_experiment.xbuilder import OptimizationOptions
        from experimental_experiment.torch_interpreter import to_onnx
        from experimental_experiment.torch_interpreter.oxs_dispatcher import (
            OxsDebugDispatcher,
        )
        with torch.no_grad():
            model, input_tensors = get_llama_model()
            input_tensors = input_tensors[0]
            to_onnx(
                model,
                input_tensors,
                input_names=[f"input{i}" for i in range(len(input_tensors))],
                options=OptimizationOptions(patterns=None),
                verbose=0,
                dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
            )"""
    def __init__(self, verbose: int = 0, raise_exc: bool = True):
        super(OxsDispatcher, self).__init__({}, verbose=verbose)
        self._submodule = None
        self.raise_exc = raise_exc
[docs]
    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: "GraphBuilder",  # noqa: F821
    ) -> Optional[Callable]:
        if self.raise_exc:
            res = OxsDispatcher.fallback(self, name, None, args, kwargs, builder)
            res(builder, False, None, *args, **kwargs)
            if self.verbose > 1:
                print(f"[OxsDebugDispatcher.fallback] fallback verified for {name!r}: {res}")
        else:
            try:
                res = OxsDispatcher.fallback(self, name, None, args, kwargs, builder)
            except (
                AssertionError,
                AttributeError,
                RuntimeError,
                TypeError,
                ValueError,
            ) as e:
                if self.verbose > 1:
                    print(
                        f"[OxsDebugDispatcher.fallback] fallback "
                        f"failed for {name!r} with e={e}"
                    )
                return fct
            try:
                res(builder, False, None, *args, **kwargs)
            except (
                AssertionError,
                AttributeError,
                RuntimeError,
                TypeError,
                ValueError,
            ) as e:
                if self.verbose > 1:
                    print(
                        f"[OxsDebugDispatcher.fallback] fallback "
                        f"unverified for {name!r} with e={e}"
                    )
                return fct
        if self.verbose > 1:
            print(f"[OxsDebugDispatcher.fallback] fallback verified for {name!r} with {res}")
        return fct or res