from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
from .npx_array_api import BaseArrayApi, ArrayApiError
from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN
from .npx_types import (
    DType,
    ElemType,
    OptParType,
    ParType,
    TensorType,
    OptTensorType,
    TupleType,
)
class Par:
    """
    Defines a named parameter.

    :param name: parameter name
    :param dtype: parameter type (bool, int, str, float), it must be
        a subclass of `ParType`
    :param value: value of the parameter if known
    :param parent_op: node type it belongs to, it is mandatory even
        though the signature gives it a default value

    Every comparison operator raises `NotImplementedError`.
    """

    def __init__(
        self,
        name: str,
        dtype: ParType,
        value: Optional[Any] = None,
        parent_op: Optional[Tuple[str, str, int]] = None,
    ):
        if not issubclass(dtype, ParType):
            raise TypeError(
                f"dtype for parameter {name!r} must be of " f"ParType not {dtype}."
            )
        if parent_op is None:
            raise ValueError(f"parent_op must be filled for paramenter {name!r}.")
        self.name = name
        self.dtype = dtype
        self.value = value
        self.parent_op = parent_op

    def __repr__(self):
        "usual"
        # the value is only displayed when it is known
        shown = [f"{self.name!r}", f"{self.dtype.type_name()}"]
        if self.value is not None:
            shown.append(f"{self.value!r}")
        shown.append(f"parent_op={self.parent_op!r}")
        return f"{self.__class__.__name__}({', '.join(shown)})"

    @property
    def onnx_type(self):
        "Returns the corresponding onnx type."
        return self.dtype.onnx_type()

    def __eq__(self, x):
        "Should not be used."
        raise NotImplementedError("__eq__ should not be used.")

    def __ne__(self, x):
        "Should not be used."
        # `!=` is dispatched to `__ne__`; without it Python falls back
        # on `__eq__` and reports the wrong operator.
        raise NotImplementedError("__neq__ should not be used.")

    # Former (misspelled) name, kept for callers invoking it explicitly.
    __neq__ = __ne__

    def __lt__(self, x):
        "Should not be used."
        raise NotImplementedError("__lt__ should not be used.")

    def __gt__(self, x):
        "Should not be used."
        raise NotImplementedError("__gt__ should not be used.")

    def __le__(self, x):
        "Should not be used."
        raise NotImplementedError("__le__ should not be used.")

    def __ge__(self, x):
        "Should not be used."
        raise NotImplementedError("__ge__ should not be used.")
class ManyIdentity:
    """
    Holds several instances of :class:`Var <onnx_array_api.npx.npx_var.Var>`.

    :param inputs: the variables to hold
    :param input_indices: one index per input, zero for every input
        if None
    """

    def __init__(self, *inputs, input_indices=None):
        self.inputs = inputs
        self.onnx_op = None
        if input_indices is None:
            self.input_indices = [0] * len(self.inputs)
        else:
            self.input_indices = input_indices
        self.n_var_outputs = len(self.inputs)
        self.onnx_op_kwargs = {}
        self._prefix = "ManyIdentity_"

    def __repr__(self) -> str:
        "usual"
        args = list(map(repr, self.inputs))
        # default=0 keeps repr working when the instance holds no input
        if max(self.input_indices, default=0) > 0:
            args.append(f"input_indices={self.input_indices}")
        s = ", ".join(args)
        return f"{self.__class__.__name__}({s})"

    def __len__(self):
        "Returns the number of merged variables."
        return len(self.inputs)

    def __getitem__(self, i):
        "Returns the ith elements."
        return self.inputs[i]

    def to_onnx(
        self,
        target_opsets: Optional[Dict[str, int]] = None,
        as_function: bool = False,
        name: Optional[str] = None,
        domain: Optional[str] = None,
        attributes: Optional[List[str]] = None,
        constraints: Optional[Dict[Any, TensorType]] = None,
        ir_version: Optional[int] = None,
    ) -> Union[ModelProto, FunctionProto, List[Any]]:
        """
        Converts the recursive graph to ONNX.

        :param target_opsets: dictionary `{opset: version}`, if None,
            it is replaced by a copy of `DEFAULT_OPSETS`
        :param as_function: conversion to :class:`onnx.FunctionProto`
            or :class:`onnx.ModelProto`
        :param name: function name if *as_function* is True
        :param domain: function domain if *as_function* is True
        :param attributes: function attributes if any
        :param constraints: specifies a precise type for the type
            constraints when a function allows more than one type,
            this works if there is only one variable to be converted
        :param ir_version: forwarded to the graph builder
        :return: ModelProto, FunctionProto, or the list
            `[local functions, FunctionProto]` when *as_function* is True
            and the builder collected local functions
        :raises RuntimeError: if the number of outputs of the result
            differs from the number of held variables
        """
        from .npx_graph_builder import _GraphBuilder

        if target_opsets is None:
            target_opsets = DEFAULT_OPSETS.copy()
        g = _GraphBuilder(
            target_opsets,
            as_function=as_function,
            name=name,
            domain=domain,
            attributes=attributes,
            constraints=constraints,
            ir_version=ir_version,
        )
        # A variable shared by several inputs is appended only once;
        # identity (id) is used because Var overloads `==`.
        done = set()
        outputs = []
        for var in self.inputs:
            vs = var._get_vars()
            for var2 in vs:
                key = id(var2)
                if key in done:
                    continue
                g.append(var2)
                done.add(key)
            # the last variable returned by _get_vars is the output
            outputs.append(vs[-1])
        onx = g.to_onnx(output_vars=outputs)
        if as_function:
            if len(outputs) != len(onx.output):
                raise RuntimeError(
                    f"Mismatch number of outputs, expecting {len(outputs)}, "
                    f"got ({len(onx.output)})."
                )
            if g.functions_:
                return [g.functions_, onx]
            return onx
        if len(outputs) != len(onx.graph.output):
            raise RuntimeError(
                f"Mismatch number of outputs, expecting {len(outputs)}, "
                f"got ({len(onx.graph.output)})."
            )
        return onx
[docs]class Var(BaseArrayApi):
    """
    Defines a variable, a result...
    :param inputs: list of inputs
    :param op: operator to apply on the inputs
    :param dtype: type of the variable, a `DType` instance is stored
        as the operator attribute *dtype*
    :param inline: True to reduce the use of function and inline
        small functions, this only applies if *op* is a function
    :param n_var_outputs: number of the operator outputs
    :param input_indices: to select a specific output from the input
        operator
    :param kwargs: operator attributes
    Private attribute:
    :param onnx_input_type_: names given to the variables
    """
    def __array_namespace__(self, api_version: Optional[str] = None):
        """
        Always fails, the array namespace is expected to be requested
        from an eager tensor and not from this class.

        :param api_version: unused
        :raises RuntimeError: unconditionally
        """
        message = (
            f"This function should never be called for class {type(self)}. "
            f"It should be called for an eager tensor."
        )
        raise RuntimeError(message)
    @staticmethod
    def get_cst_var():
        """
        Returns the pair of functions `(cst, var)` defined in
        module `npx_core_api`.
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import - TODO confirm).
        from .npx_core_api import cst, var
        return cst, var
    class _setter_do:
        """
        Callable remembering a variable and the indices given to
        `Var._setter.__getitem__`. Calling it with the new values builds
        the variable holding the modified result.
        """

        def __init__(self, parent: "Var", *args):
            self.parent = parent.self_var
            self.args = args

        def __call__(self, new_values):
            """
            Returns a copy of `self.parent` where values
            whose indices are indicated by `args` are replaced
            by `new_values`.

            :raises NotImplementedError: unless there is exactly one
                index which is an integer, a slice or a `Var`
            """
            if len(self.args) == 1:
                index = self.args[0]
                if isinstance(index, (int, slice)):
                    return self._setitem1_slice(index, new_values)
                if isinstance(index, Var):
                    return self._setitem1_where(index, new_values)
            raise NotImplementedError(
                f"This expression is not yet implemented for args={self.args}."
            )

        def _setitem1_where(self, index, new_values):
            "Index is a variable: the update is done with operator `Where`."
            cst, var = Var.get_cst_var()
            if isinstance(new_values, (int, float, bool)):
                new_values = np.array(new_values)
            if isinstance(new_values, Var):
                replacement = new_values
            elif isinstance(new_values, np.ndarray):
                # constants take the type of the modified variable
                replacement = var(cst(new_values), self.parent, op="CastLike")
            else:
                raise TypeError(f"Unexpected type for new_values: {type(new_values)}.")
            return var(index, replacement, self.parent, op="Where")

        def _setitem1_slice(self, index, new_values):
            """
            Index is an integer or a slice: the update is done with
            operator `ScatterElements` along the first axis.
            """
            cst, var = Var.get_cst_var()
            if isinstance(index, slice):
                start = 0 if index.start is None else index.start
                stop, step = index.stop, index.step
            elif isinstance(index, int):
                start, stop, step = index, index + 1, 1
            else:
                raise NotImplementedError(
                    f"Unable to assign new values due to "
                    f"unexpected type {type(index)!r}."
                )
            is_array = isinstance(new_values, np.ndarray)
            if stop is None:
                # an open slice is only supported when the number of
                # new values tells where it ends
                if not is_array:
                    raise NotImplementedError(f"No implementation if stop is  {stop}.")
                stop = start + new_values.size
            positions = np.arange(start, stop, step or 1).astype(np.int64)
            if is_array:
                updates = new_values
            else:
                updates = np.full(positions.shape, new_values)
            return var(
                self.parent, cst(positions), cst(updates), op="ScatterElements", axis=0
            )
    class _setter:
        """
        Object stored in attribute `set` of a variable, indexing it
        returns a `Var._setter_do` bound to that variable.
        """

        def __init__(self, parent: "Var"):
            self.parent = parent

        def __getitem__(self, *args):
            owner = self.parent
            return Var._setter_do(owner, *args)
    def __init__(
        self,
        *inputs: List[Any],
        op: Optional[
            Union[Callable, str, Tuple[str, str], FunctionProto, ModelProto, NodeProto]
        ] = None,
        dtype: Optional[Union[type, DType]] = None,
        inline: bool = False,
        n_var_outputs: int = 1,
        input_indices: Optional[List[int]] = None,
        **kwargs,
    ):
        """
        Stores the inputs, the operator and its attributes.

        :raises TypeError: if *dtype*, an input or *input_indices*
            has an unexpected type
        :raises RuntimeError: if *input_indices* and the inputs do not have
            the same length or if no operator is given to a variable
            which is neither an `Input` nor a `Cst`
        """
        self.inputs = list(inputs)
        self.n_var_outputs = n_var_outputs
        self.inline = inline
        self._annotation = None

        # onnx_op is None (a constant) or a pair whose second
        # element is the operator
        if op is None:
            onnx_op = None
        elif isinstance(op, tuple):
            onnx_op = op  # domain, operator name
        elif isinstance(op, str):
            onnx_op = ("", op)  # operator name
        elif isinstance(op, (FunctionProto, ModelProto, NodeProto)):
            onnx_op = (ONNX_DOMAIN, op)
        else:
            onnx_op = (None, op)  # function to call
        self.onnx_op = onnx_op
        self.onnx_op_kwargs = kwargs
        self._prefix = None

        if isinstance(dtype, DType):
            # regular parameter
            self.onnx_op_kwargs["dtype"] = dtype
        elif dtype is None or hasattr(dtype, "type_name"):
            self.dtype = dtype
        else:
            raise TypeError(f"Unexpected type {type(dtype)} for dtype.")

        for pos, item in enumerate(self.inputs):
            if isinstance(item, type):
                raise TypeError(f"Unexpected type for input {pos} - {item}.")
            if isinstance(item, Var):
                # This step is needed when Var.__setitem__ was called to
                # modify the variable.
                self.inputs[pos] = item.self_var
                continue
            if (
                isinstance(item, np.ndarray)
                and item.size > 0
                and isinstance(item.ravel()[0], (np.ndarray, Var))
            ):
                raise TypeError(
                    f"Unexpected type for input {pos}: {type(item)}, "
                    f"{item.ravel()[0]}, op={op!r}"
                )
        self.inputs = tuple(self.inputs)

        if input_indices is None:
            input_indices = [0] * len(self.inputs)
        elif not isinstance(input_indices, list):
            raise TypeError(
                f"input_indices is {type(input_indices)} "
                f"but len(inputs)={len(inputs)}."
            )
        self.input_indices = input_indices
        if len(self.input_indices) != len(self.inputs):
            raise RuntimeError(
                f"length mismatch len(self.input_indices)="
                f"{len(self.input_indices)} != len(self.inputs)="
                f"{len(self.inputs)}."
            )
        if self.onnx_op is None and not isinstance(self, (Input, Cst)):
            raise RuntimeError(f"This case is not allowed: {self!r}.")
        self.set = Var._setter(self)
        self.current_var_ = None
    @property
    def annotation(self):
        """
        Returns a type if known for the Var itself: the stored
        annotation, otherwise a `TensorType` built from operator
        attribute *dtype* when it is a `DType`, otherwise None.
        """
        if self._annotation is not None:
            return self._annotation
        dtype = self.onnx_op_kwargs.get("dtype")
        if isinstance(dtype, DType):
            return TensorType[dtype]
        return None
    @property
    def self_var(self):
        """
        Returns `current_var_` when it is set, the variable itself
        otherwise. `current_var_` is meant to hold the state of the
        variable after a call to `__setitem__`.

        :raises AttributeError: if attribute `current_var_` is missing
        """
        if hasattr(self, "current_var_"):
            current = self.current_var_
            return self if current is None else current
        raise AttributeError(
            f"Class {type(self)} is missing attribute 'current_var_'."
        )
    def __call__(self):
        "Returns property `self_var`."
        current = self.self_var
        return current
    def __repr__(self) -> str:
        """
        usual, every input is abbreviated to the first letter of
        its class name followed by a dot.
        """
        args = [f"{inp.__class__.__name__[0]}." for inp in self.inputs]
        if self.onnx_op is not None:
            args.append(f"op={self.onnx_op!r}")
        if self.n_var_outputs != 1:
            args.append(f"n_var_outputs={self.n_var_outputs!r}")
        # default=0: a variable without any input must still be
        # displayable, __init__ relies on repr to build an error message
        if max(self.input_indices, default=0) != 0:
            args.append(f"input_indices={self.input_indices!r}")
        args.extend(f"{k}={v!r}" for k, v in sorted(self.onnx_op_kwargs.items()))
        return f"{self.__class__.__name__}({', '.join(args)})"
[docs]    def set_onnx_name(self, prefix: str):
        """
        Forces this variable to get this name during
        :param prefix: prefix
        """
        self._prefix = prefix 
    def _get_vars(self):
        """
        Walks through the inputs starting from this variable and returns
        the list of visited variables, inputs first, this variable (or
        what replaces it) last.

        Variables calling a function with `inline=True` are replaced by
        the result of the function. No deduplication happens here: a
        variable reachable through several paths is listed once per path.

        :return: list of variables
        :raises TypeError: if an inlined function does not return a
            `Var` or a `ManyIdentity`, or if an input has an unsupported type
        :raises RuntimeError: if a variable uses an input which does not
            appear before it in the result
        """
        vs = []
        # `stack` is consumed as a queue: items are popped from the end
        # and inputs are inserted at the beginning.
        stack = [self.self_var]
        # id(inlined variable) -> what replaces it
        replacement = {}
        # id(raw constant) -> constant variable, only consulted by the
        # consistency check at the end of this method
        replacement_cst = {}
        # filled below but never read in this method
        deleted = []
        while len(stack) > 0:
            var = stack.pop()
            key = id(var)
            if key in replacement:
                # follows the chain of replacements up to the final variable
                while key in replacement:
                    var = replacement[key]
                    key = id(var)
            if var.onnx_op is not None and var.onnx_op[0] is None and var.inline:
                # inlined function: the function is called and its result
                # is processed right away instead of the call itself
                fct = var.onnx_op[1]
                applied = fct(*var.inputs, **var.onnx_op_kwargs)
                if isinstance(applied, (ManyIdentity, Var)):
                    stack.append(applied)
                    replacement[id(var)] = applied
                    deleted.append(var)
                    continue
                raise TypeError(
                    f"Unexpected type {type(applied)} as output of " f"function {fct}."
                )
            vs.append(var)
            for i in reversed(var.inputs):
                if isinstance(i, Var):
                    stack.insert(0, i)
                    continue
                if isinstance(i, np.ndarray):
                    cst = Var.get_cst_var()[0]
                    replacement_cst[id(i)] = cst(i)
                    continue
                if isinstance(i, (int, float, bool)):
                    cst = Var.get_cst_var()[0]
                    replacement_cst[id(i)] = cst(np.array(i))
                    continue
                if isinstance(i, tuple):
                    # only tuples of integers are accepted
                    if all(map(lambda x: isinstance(x, int), i)):
                        cst = Var.get_cst_var()[0]
                        replacement_cst[id(i)] = cst(np.array(list(i), dtype=np.int64))
                        continue
                    # NOTE(review): `self.f` is not set anywhere in the
                    # visible part of this class - confirm it exists,
                    # otherwise both messages below fail with AttributeError.
                    if any(map(lambda t: isinstance(t, Var), i)):
                        raise TypeError(
                            f"Unexpected types in tuple "
                            f"({[type(t) for t in i]}), "
                            f"function {self.f} from module {self.f.__module__!r}."
                        )
                    raise TypeError(
                        f"Unsupported tuple {i!r}, "
                        f"function {self.f} from module {self.f.__module__!r}."
                    )
                if i is None:
                    # optional input
                    continue
                raise TypeError(
                    f"Unexpected type {type(i)} for an input of node {var}."
                )
        # inputs were collected after the variables using them
        res = list(reversed(vs))
        # replacement: a node calling a function can either
        # remain a call to a local function or be replaced inline
        # by the code of the function.
        # replacement keeps a map from a function call to the result
        # replacing it to avoid calling the same function
        # twice.
        new_res = []
        for r in res:
            new_inputs = []
            new_indices = []
            repl = False
            for v, ind in zip(r.inputs, r.input_indices):
                key = id(v)
                if key in replacement:
                    while key in replacement:
                        var = replacement[key]
                        key = id(var)
                    new_inputs.append(var)
                    new_indices.append(ind)
                    repl = True
                else:
                    new_inputs.append(v)
                    new_indices.append(ind)
            if repl:
                # the rebuilt variable becomes the replacement of the
                # original one for the variables using it
                new_r = r.replace_inputs(new_inputs, input_indices=new_indices)
                replacement[id(r)] = new_r
                new_res.append(new_r)
            else:
                new_res.append(r)
        # check the graph is consistent
        known = {}
        for r in new_res:
            known[id(r)] = r
            if isinstance(r, (Cst, Input)):
                continue
            for ind, i in enumerate(r.inputs):
                if i is None:
                    # optional input
                    continue
                if id(i) in replacement_cst:
                    # constant to replace
                    continue
                if id(i) not in known:
                    raise RuntimeError(
                        f"An input {ind} ({id(i)}, type={type(i)}) from "
                        f"{id(r)}-{r} is not known, it is not produced by a "
                        f"previous var (scheduled for replacement: "
                        f"{id(i) in replacement}). This also happens if "
                        f"a constant is not wrapped by 'cst(.)'."
                    )
        return new_res
    @property
    def is_function(self):
        """
        Tells if this variable encapsulates a function, which is the
        case when the operator is set with None as its first element.
        """
        op = self.onnx_op
        return op is not None and op[0] is None
[docs]    def to_onnx(
        self,
        target_opsets: Optional[Dict[str, int]] = None,
        as_function: bool = False,
        name: Optional[str] = None,
        domain: Optional[str] = None,
        attributes: Optional[List[str]] = None,
        constraints: Optional[Dict[Any, TensorType]] = None,
        ir_version: Optional[int] = None,
    ) -> Union[ModelProto, FunctionProto, List[Any]]:
        """
        Converts the recursive graph to ONNX.
        :param target_opsets: dictionary `{opset: version}`
        :param as_function: conversion to :class:`onnx.FunctionProto`
            or :class:`onnx.ModelProto`
        :param name: function name if *as_function* is True
        :param domain: function domain if *as_function* is True
        :param attributes: function attributes if any
        :param constraints: specifies a precise type for the type
            constraints when a function allows more than one type,
            this works if there is only one variable to be converted
        :return: ModelProto, FunctionProto
        """
        from .npx_graph_builder import _GraphBuilder
        # Var.to_onnx
        if target_opsets is None:
            target_opsets = DEFAULT_OPSETS
        vs = self._get_vars()
        g = _GraphBuilder(
            target_opsets,
            as_function=as_function,
            name=name,
            domain=domain,
            attributes=attributes,
            constraints=constraints,
            ir_version=ir_version,
        )
        for var in vs:
            g.append(var)
        onx = g.to_onnx()
        if as_function and len(g.functions_) > 0:
            return [g.functions_, onx]
        return onx 
    # Operators
    def __iter__(self):
        """
        The :epkg:`Array API` does not define this function (2022/12).
        Iterating always fails, this method only exists to give
        a more explicit error message.

        :raises ArrayApiError: unconditionally
        """
        message = (
            f"Iterators are not implemented in the generic case. "
            f"Every function using them cannot be converted into ONNX "
            f"(Var - {type(self)})."
        )
        raise ArrayApiError(message)
    def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var":
        """
        Builds the variable `op_name(self, ov)`.

        :param ov: right operand, a python scalar, a numpy array or
            a `Cst` is first cast to the type of *self* with `CastLike`
        :param op_name: operator name
        :param kwargs: operator attributes (for example *direction*
            for `BitShift`)
        :return: the new variable
        """
        var = Var.get_cst_var()[1]
        if isinstance(ov, (int, float, bool, np.ndarray, Cst)):
            ov = var(ov, self.self_var, op="CastLike")
        # the attributes are forwarded whatever the operand is
        return var(self.self_var, ov, op=op_name, **kwargs)
    def _binary_op_right(self, ov: "Var", op_name: str, **kwargs) -> "Var":
        """
        Builds the variable `op_name(ov, self)`, *self* is the
        right operand.

        :param ov: left operand, a python scalar, a numpy array or
            a `Cst` is first cast to the type of *self* with `CastLike`
        :param op_name: operator name
        :param kwargs: operator attributes
        :return: the new variable
        """
        var = Var.get_cst_var()[1]
        if isinstance(ov, (int, float, bool, np.ndarray, Cst)):
            ov = var(ov, self.self_var, op="CastLike")
        # the attributes are forwarded whatever the operand is
        return var(ov, self.self_var, op=op_name, **kwargs)
    def __neg__(self) -> "Var":
        """
        Implements ``-self`` with operator `Neg`.
        No cast is inserted.
        """
        _, var = Var.get_cst_var()
        return var(self.self_var, op="Neg")
    def __invert__(self) -> "Var":
        """
        Implements ``~self`` with operator `BitwiseNot`.
        No cast is inserted.
        """
        _, var = Var.get_cst_var()
        return var(self.self_var, op="BitwiseNot")
    def __add__(self, ov: "Var") -> "Var":
        """
        Implements ``self + ov`` with operator `Add`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Add")
    def __radd__(self, ov: "Var") -> "Var":
        """
        Implements ``ov + self`` with operator `Add`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="Add")
    def __sub__(self, ov: "Var") -> "Var":
        """
        Implements ``self - ov`` with operator `Sub`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Sub")
    def __rsub__(self, ov: "Var") -> "Var":
        """
        Implements ``ov - self`` with operator `Sub`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="Sub")
    def __mul__(self, ov: "Var") -> "Var":
        """
        Implements ``self * ov`` with operator `Mul`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Mul")
    def __rmul__(self, ov: "Var") -> "Var":
        """
        Implements ``ov * self`` with operator `Mul`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="Mul")
    def __matmul__(self, ov: "Var") -> "Var":
        """
        Implements ``self @ ov`` with operator `MatMul`,
        the node is built by `_binary_op`.
        `__rmatmul__` would not be called as a numpy array
        overwrites `__matmul__` on its side.
        """
        return self._binary_op(ov, op_name="MatMul")
    def __truediv__(self, ov: "Var") -> "Var":
        """
        Implements ``self / ov`` with operator `Div`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Div")
    def __rtruediv__(self, ov: "Var") -> "Var":
        """
        Implements ``ov / self`` with operator `Div`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="Div")
    def __mod__(self, ov: "Var") -> "Var":
        """
        Implements ``self % ov`` with operator `Mod`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Mod")
    def __rmod__(self, ov: "Var") -> "Var":
        """
        Implements ``ov % self`` with operator `Mod`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="Mod")
    def __pow__(self, ov: "Var") -> "Var":
        """
        Implements ``self ** ov`` with operator `Pow`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Pow")
    def __rpow__(self, ov: "Var") -> "Var":
        """
        Implements ``ov ** self`` with operator `Pow`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="Pow")
    def __lt__(self, ov: "Var") -> "Var":
        """
        Implements ``self < ov`` with operator `Less`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Less")
    def __le__(self, ov: "Var") -> "Var":
        """
        Implements ``self <= ov`` with operator `LessOrEqual`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="LessOrEqual")
    def __gt__(self, ov: "Var") -> "Var":
        """
        Implements ``self > ov`` with operator `Greater`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="Greater")
    def __ge__(self, ov: "Var") -> "Var":
        """
        Implements ``self >= ov`` with operator `GreaterOrEqual`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="GreaterOrEqual")
    def __eq__(self, ov: "Var") -> "Var":
        """
        Implements ``self == ov`` with operator `Equal`,
        the node is built by `_binary_op`. The result is a
        variable, not a boolean.
        """
        return self._binary_op(ov, op_name="Equal")
    def __ne__(self, ov: "Var") -> "Var":
        """
        Implements ``self != ov`` with operator `Equal`
        followed by operator `Not`.
        """
        _, var = Var.get_cst_var()
        equal = self._binary_op(ov, "Equal")
        return var(equal, op="Not")
    def __lshift__(self, ov: "Var") -> "Var":
        """
        Implements ``self << ov`` with operator `BitShift`
        and attribute *direction* set to `LEFT`.
        """
        return self._binary_op(ov, op_name="BitShift", direction="LEFT")
    def __rshift__(self, ov: "Var") -> "Var":
        """
        Implements ``self >> ov`` with operator `BitShift`
        and attribute *direction* set to `RIGHT`.
        """
        return self._binary_op(ov, op_name="BitShift", direction="RIGHT")
    def __and__(self, ov: "Var") -> "Var":
        """
        Implements ``self & ov`` with operator `BitwiseAnd`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="BitwiseAnd")
    def __rand__(self, ov: "Var") -> "Var":
        """
        Implements ``ov & self`` with operator `BitwiseAnd`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="BitwiseAnd")
    def __or__(self, ov: "Var") -> "Var":
        """
        Implements ``self | ov`` with operator `BitwiseOr`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="BitwiseOr")
    def __ror__(self, ov: "Var") -> "Var":
        """
        Implements ``ov | self`` with operator `BitwiseOr`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="BitwiseOr")
    def __xor__(self, ov: "Var") -> "Var":
        """
        Implements ``self ^ ov`` with operator `BitwiseXor`,
        the node is built by `_binary_op`.
        """
        return self._binary_op(ov, op_name="BitwiseXor")
    def __rxor__(self, ov: "Var") -> "Var":
        """
        Implements ``ov ^ self`` with operator `BitwiseXor`,
        the node is built by `_binary_op_right`.
        """
        return self._binary_op_right(ov, op_name="BitwiseXor")
    @property
    def T(self) -> "Var":
        "Transpose, the permutation `[1, 0]` swaps the two axes."
        _, var = Var.get_cst_var()
        return var(self.self_var, op="Transpose", perm=[1, 0])
[docs]    def astype(self, dtype) -> "Var":
        "Cast"
        var = Var.get_cst_var()[1]
        if isinstance(dtype, Var):
            return var(self.self_var, dtype, op="CastLike")
        if not isinstance(dtype, DType):
            raise TypeError(f"dtype cannot be {type(dtype)}.")
        return var(self.self_var, op="Cast", to=dtype) 
    @property
    def shape(self) -> "Var":
        "Shape, returned as a variable built with operator `Shape`."
        _, var = Var.get_cst_var()
        return var(self.self_var, op="Shape")
[docs]    def reshape(self, shape: "Var") -> "Var":
        "Reshape"
        cst, var = Var.get_cst_var()
        if isinstance(shape, (tuple, list)):
            shape = np.array(shape, dtype=np.int64)
        else:
            shape = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape")
        return var(self.self_var, shape, op="Reshape") 
[docs]    def reduce_function(
        self,
        reduce_op,
        axis: OptTensorType[ElemType.int64, "I"] = None,
        keepdims: ParType[int] = 0,
    ) -> "Var":
        "See :func:`numpy.sum` or any other reduce function."
        var = Var.get_cst_var()[1]
        if axis is None:
            return var(self.self_var, op=reduce_op, keepdims=keepdims)
        if isinstance(axis, int):
            axis = [axis]
        if isinstance(axis, (tuple, list)):
            cst = Var.get_cst_var()[0]
            axis = cst(np.array(axis, dtype=np.int64))
        return var(self.self_var, axis, op=reduce_op, keepdims=keepdims) 
[docs]    def sum(
        self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
    ) -> "Var":
        "See :func:`numpy.sum`."
        return self.reduce_function("ReduceSum", axis=axis, keepdims=keepdims) 
[docs]    def mean(
        self, axis: OptParType[TupleType[int]] = None, keepdims: ParType[int] = 0
    ) -> "Var":
        "See :func:`numpy.mean`."
        return self.reduce_function("ReduceMean", axis=axis, keepdims=keepdims) 
[docs]    def min(
        self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
    ) -> "Var":
        "See :func:`numpy.min`."
        return self.reduce_function("ReduceMin", axis=axis, keepdims=keepdims) 
[docs]    def max(
        self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
    ) -> "Var":
        "See :func:`numpy.max`."
        return self.reduce_function("ReduceMax", axis=axis, keepdims=keepdims) 
[docs]    def prod(
        self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
    ) -> "Var":
        "See :func:`numpy.prod`."
        return self.reduce_function("ReduceProd", axis=axis, keepdims=keepdims) 
[docs]    def copy(self) -> "Var":
        """
        Returns a copy of self (use of Identity node).
        """
        var = Var.get_cst_var()[1]
        return var(self.self_var, op="Identity") 
[docs]    def flatten(self) -> "Var":
        """
        Flattens a matrix (see :meth:`numpy.ndarray.flatten`).
        :param axis: only flatten from axis to the end.
        :return: :class:`Var <onnx_array_api.npx.npx_var.Var>`
        """
        cst, var = Var.get_cst_var()
        return var(
            var(self.self_var, op="Flatten", axis=0),
            cst(np.array([0], dtype=np.int64)),
            op="Squeeze",
        ) 
[docs]    def get(self, index: int) -> "Var":
        """
        If an operator or a function returns more than one output,
        this takes only one.
        :param index: index of the output to select
        :return: Var
        """
        if index < 0 or index >= self.n_var_outputs:
            raise ValueError(
                f"index={index} must be positive and < {self.n_var_outputs} "
                f"for var={self!r}."
            )
        return Var(self.self_var, input_indices=[index], op="Identity") 
    def __getitem__(self, index: Any) -> "Var":
        """
        Deals with multiple scenarios.

        * *index* is an integer and the object produces multiple
          outputs and this returns one of them (**scenario 0**)
        * *index* is an integer or a slice, a tuple of integers and slices,
          example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**).
          A single integer, alone or only combined with full slices
          (`[3]`, `[:, 3]`), becomes a ``Gather`` node, an index only made
          of full slices (`[:]`) becomes an ``Identity`` node, anything
          else becomes a ``Slice`` node followed by a ``Squeeze`` node
          on the axes indexed by an integer.
        * *index* is an *ONNX* object (more precisely an instance of
          :class:`Var <onnx_array_api.npx.npx_var.Var>`),
          then the method relies on its annotation. Without annotation
          or with a boolean one, both tensors are reshaped into vectors
          and the selection is done with ``Compress``,
          example: `mat[mat == 0]`. With an integer annotation
          (int64 or int32), the selection is done with ``Gather``
          along the first axis (**scenario 2**)

        :param index: integer, slice, tuple of integers and slices, or Var
        :return: Var
        :raises TypeError: if the object has multiple outputs and *index*
            is not an integer, or if the annotation of a Var index is
            not a supported tensor type
        :raises NotImplementedError: if a tuple index contains something
            else than integers and slices
        """
        cst, var = Var.get_cst_var()

        if self.n_var_outputs != 1:
            # scenario 0: multiple outputs, index selects one of them
            if not isinstance(index, int):
                raise TypeError(
                    f"Only indices are allowed when selecting an output, "
                    f"not {type(index)})."
                )
            return self.get(index)

        if isinstance(index, Var):
            # scenario 2
            # we rely on the annotation if it exists
            if index.annotation is None:
                dtype_bool = True
            elif issubclass(index.annotation, TensorType):
                if index.annotation.supports_dtype(
                    DType(TensorProto.INT64)
                ) or index.annotation.supports_dtype(DType(TensorProto.INT32)):
                    dtype_bool = False
                elif index.annotation.supports_dtype(DType(TensorProto.BOOL)):
                    dtype_bool = True
                else:
                    raise TypeError(
                        f"Unexpected dtype for annotation={index.annotation!r} "
                        f"for index={index!r}."
                    )
            else:
                raise TypeError(
                    f"Unexpected annotation={index.annotation!r} "
                    f"for index={index!r}."
                )

            if dtype_bool:
                # TODO: fix this when index is an integer and the annotation unknown
                # it needs to support subgraph and tests
                new_shape = cst(np.array([-1], dtype=np.int64))
                new_self = self.reshape(new_shape)
                new_index = index.reshape(new_shape)
                return var(new_self, new_index, op="Compress")

            # dtype is int
            return var(self, index, axis=0, op="Gather")

        if isinstance(index, int):
            # Use Gather instead.
            return var(self, cst(np.array(index, dtype=np.int64)), axis=0, op="Gather")

        if not isinstance(index, tuple):
            index = (index,)
        elif not index:
            # The array contains a scalar and it needs to be returned.
            return var(self, op="Identity")

        def is_full_slice(item):
            "Tells if *item* is the slice `:`."
            return (
                isinstance(item, slice)
                and item.start is None
                and item.stop is None
                and item.step is None
            )

        # only one integer, every other item being `:`?
        int_axes = [i for i, item in enumerate(index) if isinstance(item, int)]
        if len(int_axes) == 1:
            axis = int_axes[0]
            if all(is_full_slice(item) for i, item in enumerate(index) if i != axis):
                # Use Gather instead: the integer is the gathered index,
                # its position in the tuple is the axis.
                return var(
                    self,
                    cst(np.array(index[axis], dtype=np.int64)),
                    axis=axis,
                    op="Gather",
                )

        # scenario 1
        int64_info = np.iinfo(np.int64)
        starts = []
        ends = []
        axes = []
        steps = []
        axis_squeeze = []
        needs_shape = False
        for i, ind in enumerate(index):
            if isinstance(ind, int):
                starts.append(ind)
                # -1 + 1 == 0 would describe an empty range for Slice,
                # the last element is selected with an end clamped
                # to the dimension.
                ends.append(int64_info.max if ind == -1 else ind + 1)
                axes.append(i)
                steps.append(1)
                axis_squeeze.append(i)
                continue
            if isinstance(ind, slice):
                if is_full_slice(ind):
                    continue
                step = 1 if ind.step is None else ind.step
                backward = isinstance(step, (int, np.integer)) and step < 0
                # Open bounds depend on the direction: a backward slice
                # starts from the last element and goes past the first one.
                if ind.start is not None:
                    start = ind.start
                else:
                    start = -1 if backward else 0
                if ind.stop is not None:
                    end = ind.stop
                elif backward:
                    end = int64_info.min
                else:
                    # the dimension of axis i is only known at runtime
                    end = (None, i)
                starts.append(start)
                ends.append(end)
                axes.append(i)
                steps.append(step)
                if isinstance(end, (tuple, Var)):
                    needs_shape = True
                continue
            raise NotImplementedError(f"Not implemented for type {type(ind)!r}.")

        if not axes:
            # Only full slices (`x[:]`, `x[:, :]`): nothing to slice.
            return var(self, op="Identity")

        if all(s == 1 for s in steps):
            # Slice does not need the optional input steps.
            steps = None
        else:
            steps = np.array(steps, dtype=np.int64)

        starts = np.array(starts, dtype=np.int64)
        axes = np.array(axes, dtype=np.int64)

        if needs_shape:
            # At least one end is dynamic: every end becomes a tensor
            # with one element and they are concatenated.
            shape = self.shape
            conc = []
            for e in ends:
                if isinstance(e, tuple):
                    conc.append(
                        var(shape, cst(np.array([e[1]], np.int64)), op="Gather")
                    )
                elif isinstance(e, Var):
                    conc.append(e.reshape(np.array([-1], dtype=np.int64)))
                else:
                    conc.append(np.array([e], dtype=np.int64))
            if len(conc) > 1:
                conc_cst = [v if isinstance(v, Var) else cst(v) for v in conc]
                ends = var(*conc_cst, op="Concat", axis=0)
            else:
                ends = conc[0]
        else:
            ends = np.array(ends, dtype=np.int64)

        sliced_args = [starts, ends, axes]
        if steps is not None:
            sliced_args.append(steps)
        sliced_args_cst = [v if isinstance(v, Var) else cst(v) for v in sliced_args]
        sliced = var(self.self_var, *sliced_args_cst, op="Slice")
        if axis_squeeze:
            # axes indexed by an integer disappear from the result
            return var(
                sliced,
                cst(np.array(axis_squeeze, dtype=np.int64)),
                op="Squeeze",
            )
        return sliced
    def __setitem__(self, index, values):
        """
        Implements ``self[index] = values``.

        The object returned by ``self.set[index](values)`` is stored
        in ``current_var_`` and ``input_indices`` is reset to None.

        :param index: key given to ``self.set``
        :param values: values given to the object ``self.set[index]`` returns
        """
        self.current_var_ = self.set[index](values)
        self.input_indices = None
class Cst(Var):
    """
    Defines a constant.

    The value is stored as a numpy array wrapped into an ``Identity`` node.
    Python scalars and lists are converted with the narrowest type among
    bool, int64, float64 able to hold every value.

    :param cst: numpy array, bool, int, float or list of such scalars
    """

    def __init__(self, cst: Any):
        # From the narrowest to the widest type. bool comes first
        # because bool is a subclass of int.
        ladder = ((bool, np.bool_), (int, np.int64), (float, np.float64))

        if isinstance(cst, np.ndarray):
            value = cst
        elif isinstance(cst, (bool, int, float)):
            dtype = next(dt for tp, dt in ladder if isinstance(cst, tp))
            value = np.array(cst, dtype=dtype)
        elif isinstance(cst, list):
            # Widens the accepted python types step by step and stops
            # at the first level holding every element.
            accepted = ()
            for tp, dt in ladder:
                accepted += (tp,)
                if all(isinstance(item, accepted) for item in cst):
                    value = np.array(cst, dtype=dt)
                    break
            else:
                raise ValueError(
                    f"Unable to convert cst (type={type(cst)}), " f"value={cst}."
                )
        else:
            raise NotImplementedError(
                f"Constant of type {type(cst)} are not implemented yet. "
                f"You should not use 'float32(x)' but 'array(x, dtype=float32)'."
            )

        Var.__init__(self, value, op="Identity")
        self._prefix = "cst"