from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
from .npx_array_api import BaseArrayApi, ArrayApiError
from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN
from .npx_types import (
DType,
ElemType,
OptParType,
ParType,
TensorType,
OptTensorType,
TupleType,
)
class Par:
"""
Defines a named parameter.
:param name: parameter name
:param dtype: parameter type (bool, int, str, float)
:param value: value of the parameter if known
:param parent_op: node type it belongs to
"""
def __init__(
self,
name: str,
dtype: ParType,
value: Optional[Any] = None,
parent_op: Optional[Tuple[str, str, int]] = None,
):
if not issubclass(dtype, ParType):
raise TypeError(
f"dtype for parameter {name!r} must be of ParType, not {dtype}."
)
if parent_op is None:
raise ValueError(f"parent_op must be filled for paramenter {name!r}.")
self.name = name
self.dtype = dtype
self.value = value
self.parent_op = parent_op
def __repr__(self):
"usual"
if self.value is None:
return (
f"{self.__class__.__name__}({self.name!r}, {self.dtype.type_name()}, "
f"parent_op={self.parent_op!r})"
)
return (
f"{self.__class__.__name__}"
f"({self.name!r}, {self.dtype.type_name()}, {self.value!r}, "
f"parent_op={self.parent_op!r})"
)
@property
def onnx_type(self):
"Returns the corresponding onnx type."
return self.dtype.onnx_type()
def __eq__(self, x):
"Should not be used."
raise NotImplementedError("__eq__ should not be used.")
def __ne__(self, x):
"Should not be used."
raise NotImplementedError("__ne__ should not be used.")
def __lt__(self, x):
"Should not be used."
raise NotImplementedError("__lt__ should not be used.")
def __gt__(self, x):
"Should not be used."
raise NotImplementedError("__gt__ should not be used.")
def __le__(self, x):
"Should not be used."
raise NotImplementedError("__le__ should not be used.")
def __ge__(self, x):
"Should not be used."
raise NotImplementedError("__ge__ should not be used.")
class ManyIdentity:
"""
Holds several instances of :class:`Var <onnx_array_api.npx.npx_var.Var>`.
"""
def __init__(self, *inputs, input_indices=None):
self.inputs = inputs
self.onnx_op = None
if input_indices is None:
self.input_indices = [0 for i in self.inputs]
else:
self.input_indices = input_indices
self.n_var_outputs = len(self.inputs)
self.onnx_op_kwargs = {}
self._prefix = "ManyIdentity_"
def __repr__(self) -> str:
"usual"
args = list(map(repr, self.inputs))
if max(self.input_indices) > 0:
args.append(f"input_indices={self.input_indices}")
s = ", ".join(args)
return f"{self.__class__.__name__}({s})"
def __len__(self):
"Returns the number of merged variables."
return len(self.inputs)
def __getitem__(self, i):
"Returns the ith elements."
return self.inputs[i]
def to_onnx(
self,
target_opsets: Optional[Dict[str, int]] = None,
as_function: bool = False,
name: Optional[str] = None,
domain: Optional[str] = None,
attributes: Optional[List[str]] = None,
constraints: Optional[Dict[Any, TensorType]] = None,
ir_version: Optional[int] = None,
) -> Union[ModelProto, FunctionProto, List[Any]]:
"""
Converts the recursive graph to ONNX.
:param target_opsets: dictionary `{domain: version}`; if None,
it is replaced by `DEFAULT_OPSETS`
:param as_function: conversion to :class:`onnx.FunctionProto`
or :class:`onnx.ModelProto`
:param name: function name if *as_function* is True
:param domain: function domain if *as_function* is True
:param attributes: function attributes if any
:param constraints: specifies a precise type for the type
constraints when a function allows more than one type,
this works if there is only one variable to be converted
:return: ModelProto, FunctionProto, or a list `[functions, proto]`
if *as_function* is True and local functions were created
"""
from .npx_graph_builder import _GraphBuilder
# Var.to_onnx
if target_opsets is None:
target_opsets = DEFAULT_OPSETS.copy()
g = _GraphBuilder(
target_opsets,
as_function=as_function,
name=name,
domain=domain,
attributes=attributes,
constraints=constraints,
ir_version=ir_version,
)
done = set()
outputs = []
for var in self.inputs:
vs = var._get_vars()
for var2 in vs:
key = id(var2)
if key in done:
continue
g.append(var2)
done.add(key)
outputs.append(vs[-1])
onx = g.to_onnx(output_vars=outputs)
if as_function:
if len(outputs) != len(onx.output):
raise RuntimeError(
f"Mismatch number of outputs, expecting {len(outputs)}, "
f"got ({len(onx.output)})."
)
if g.functions_:
return [g.functions_, onx]
return onx
if len(outputs) != len(onx.graph.output):
raise RuntimeError(
f"Mismatch number of outputs, expecting {len(outputs)}, "
f"got ({len(onx.graph.output)})."
)
return onx
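# Illustrative sketch (assumption, not from the source): a ManyIdentity wraps
# several Var instances so they can be exported as one model with several
# outputs.
#
#   mi = ManyIdentity(var_a, var_b)   # var_a, var_b are hypothetical Var objects
#   onx = mi.to_onnx()                # ModelProto with len(mi) == 2 outputs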
class Var(BaseArrayApi):
"""
Defines a variable, the result of an operation.
:param inputs: list of inputs
:param op: operator to apply to the inputs
:param inline: True to inline small functions instead of
calling them, this only applies if *op* is a function
:param n_var_outputs: number of outputs of the operator
:param input_indices: to select a specific output from the input
operator
:param kwargs: operator attributes
Private attribute:
:param onnx_input_type_: names given to the variables
"""
def __array_namespace__(self, api_version: Optional[str] = None):
"""
Raises an exception if called.
"""
raise RuntimeError(
f"This function should never be called for class {type(self)}. "
f"It should be called for an eager tensor."
)
@staticmethod
def get_cst_var():
from .npx_core_api import cst, var
return cst, var
class _setter_do:
def __init__(self, parent: "Var", *args):
self.parent = parent.self_var
self.args = args
def __call__(self, new_values):
"""
Returns a copy of `self.parent` where the values at the
indices given by `args` are replaced by `new_values`.
"""
if len(self.args) == 1 and isinstance(self.args[0], (int, slice)):
return self._setitem1_slice(self.args[0], new_values)
if len(self.args) == 1 and isinstance(self.args[0], Var):
return self._setitem1_where(self.args[0], new_values)
raise NotImplementedError(
f"This expression is not yet implemented for args={self.args}."
)
def _setitem1_where(self, index, new_values):
cst, var = Var.get_cst_var()
if isinstance(new_values, (int, float, bool)):
new_values = np.array(new_values)
if isinstance(new_values, np.ndarray):
value = var(cst(new_values), self.parent, op="CastLike")
elif isinstance(new_values, Var):
value = new_values
else:
raise TypeError(f"Unexpected type for new_values: {type(new_values)}.")
return var(index, value, self.parent, op="Where")
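# Example of the boolean-mask assignment implemented above (sketch):
# ``x[x > 0] = 1`` builds ``Where(x > 0, CastLike(cst(1), x), x)``, i.e. the
# new value is cast to the parent's type and selected wherever the mask is true.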
def _setitem1_slice(self, index, new_values):
cst, var = Var.get_cst_var()
if isinstance(index, slice):
start = 0 if index.start is None else index.start
stop = index.stop
step = index.step
elif isinstance(index, int):
start, stop, step = index, index + 1, 1
else:
raise NotImplementedError(
f"Unable to assign new values due to "
f"unexpected type {type(index)!r}."
)
inp = self.parent
if stop is None and isinstance(new_values, np.ndarray):
stop = start + new_values.size
if stop is None:
raise NotImplementedError(f"No implementation if stop is {stop}.")
indices = np.arange(start, stop, step or 1).astype(np.int64)
if isinstance(new_values, np.ndarray):
values = new_values
else:
values = np.full(indices.shape, new_values)
return var(inp, cst(indices), cst(values), op="ScatterElements", axis=0)
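# Example of the slice assignment implemented above (sketch):
# ``x[2:5] = 0`` builds ``ScatterElements(x, indices=[2, 3, 4],
# updates=[0, 0, 0], axis=0)``; a plain integer index ``x[2] = 0`` is handled
# as the slice ``2:3``.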
class _setter:
def __init__(self, parent: "Var"):
self.parent = parent
def __getitem__(self, *args):
return Var._setter_do(self.parent, *args)
def __init__(
self,
*inputs: List[Any],
op: Optional[
Union[Callable, str, Tuple[str, str], FunctionProto, ModelProto, NodeProto]
] = None,
dtype: Optional[Union[type, DType]] = None,
inline: bool = False,
n_var_outputs: int = 1,
input_indices: Optional[List[int]] = None,
**kwargs,
):
self.inputs = list(inputs)
self.n_var_outputs = n_var_outputs
self.inline = inline
self._annotation = None
if op is None:
self.onnx_op = None # a constant
elif isinstance(op, tuple):
self.onnx_op = op # domain, operator name
elif isinstance(op, str):
self.onnx_op = ("", op) # operator name
elif isinstance(op, (FunctionProto, ModelProto, NodeProto)):
self.onnx_op = (ONNX_DOMAIN, op)
else:
self.onnx_op = (None, op) # function to call
self.onnx_op_kwargs = kwargs
self._prefix = None
if isinstance(dtype, DType):
# regular parameter
self.onnx_op_kwargs["dtype"] = dtype
elif hasattr(dtype, "type_name"):
self.dtype = dtype
elif dtype is None:
self.dtype = None
else:
raise TypeError(f"Unexpected type {type(dtype)} for dtype.")
updates = {}
for i, inp in enumerate(self.inputs):
if isinstance(inp, type):
raise TypeError(f"Unexpected type for input {i} - {inp}.")
if isinstance(inp, Var):
updates[i] = inp.self_var
if not isinstance(inp, np.ndarray):
continue
if inp.size > 0 and isinstance(inp.ravel()[0], (np.ndarray, Var)):
raise TypeError(
f"Unexpected type for input {i}: {type(inp)}, "
f"{inp.ravel()[0]}, op={op!r}"
)
# This step is needed when Var.__setitem__ was called to
# modify the variable.
for i, v in updates.items():
self.inputs[i] = v
self.inputs = tuple(self.inputs)
if input_indices is None:
self.input_indices = [0 for i in self.inputs]
elif not isinstance(input_indices, list):
raise TypeError(
f"input_indices is {type(input_indices)} "
f"but len(inputs)={len(inputs)}."
)
else:
self.input_indices = input_indices
if len(self.input_indices) != len(self.inputs):
raise RuntimeError(
f"length mismatch len(self.input_indices)="
f"{len(self.input_indices)} != len(self.inputs)="
f"{len(self.inputs)}."
)
if self.onnx_op is None:
if not isinstance(self, (Input, Cst)):
raise RuntimeError(f"This case is not allowed: {self!r}.")
self.set = Var._setter(self)
self.current_var_ = None
@property
def annotation(self):
"""Returns a type if known for the Var itself."""
if self._annotation is None:
if "dtype" in self.onnx_op_kwargs:
dtype = self.onnx_op_kwargs["dtype"]
if isinstance(dtype, DType):
return TensorType[dtype]
return self._annotation
@property
def self_var(self):
"""
Returns itself or the variable corresponding to its
state after a call to `__setitem__`.
"""
if not hasattr(self, "current_var_"):
raise AttributeError(
f"Class {type(self)} is missing attribute 'current_var_'."
)
return self if self.current_var_ is None else self.current_var_
def __call__(self):
return self.self_var
def __repr__(self) -> str:
"usual"
args = []
for inp in self.inputs:
n = inp.__class__.__name__
args.append(f"{n[0]}.")
if self.onnx_op is not None:
args.append(f"op={self.onnx_op!r}")
if self.n_var_outputs != 1:
args.append(f"n_var_outputs={self.n_var_outputs!r}")
if max(self.input_indices) != 0:
args.append(f"input_indices={self.input_indices!r}")
for k, v in sorted(self.onnx_op_kwargs.items()):
args.append(f"{k}={v!r}")
res = f"{self.__class__.__name__}({', '.join(args)})"
return res
def set_onnx_name(self, prefix: str):
"""
Forces this variable to get this name during the conversion to ONNX.
:param prefix: prefix
"""
self._prefix = prefix
def _get_vars(self):
vs = []
stack = [self.self_var]
replacement = {}
replacement_cst = {}
deleted = []
while len(stack) > 0:
var = stack.pop()
key = id(var)
if key in replacement:
while key in replacement:
var = replacement[key]
key = id(var)
if var.onnx_op is not None and var.onnx_op[0] is None and var.inline:
fct = var.onnx_op[1]
applied = fct(*var.inputs, **var.onnx_op_kwargs)
if isinstance(applied, (ManyIdentity, Var)):
stack.append(applied)
replacement[id(var)] = applied
deleted.append(var)
continue
raise TypeError(
f"Unexpected type {type(applied)} as output of " f"function {fct}."
)
vs.append(var)
for i in reversed(var.inputs):
if isinstance(i, Var):
stack.insert(0, i)
continue
if isinstance(i, np.ndarray):
cst = Var.get_cst_var()[0]
replacement_cst[id(i)] = cst(i)
continue
if isinstance(i, (int, float, bool)):
cst = Var.get_cst_var()[0]
replacement_cst[id(i)] = cst(np.array(i))
continue
if isinstance(i, tuple):
if all(map(lambda x: isinstance(x, int), i)):
cst = Var.get_cst_var()[0]
replacement_cst[id(i)] = cst(np.array(list(i), dtype=np.int64))
continue
if any(map(lambda t: isinstance(t, Var), i)):
raise TypeError(
f"Unexpected types in tuple "
f"({[type(t) for t in i]}) for an input of node {var}."
)
raise TypeError(
f"Unsupported tuple {i!r} for an input of node {var}."
)
if i is None:
continue
raise TypeError(
f"Unexpected type {type(i)} for an input of node {var}."
)
res = list(reversed(vs))
# replacement: a node calling a function can either
# remain a call to a local function or be replaced inline
# by the body of that function.
# replacement keeps a map from a function call to its result
# so that the same function is not inlined twice.
new_res = []
for r in res:
new_inputs = []
new_indices = []
repl = False
for v, ind in zip(r.inputs, r.input_indices):
key = id(v)
if key in replacement:
while key in replacement:
var = replacement[key]
key = id(var)
new_inputs.append(var)
new_indices.append(ind)
repl = True
else:
new_inputs.append(v)
new_indices.append(ind)
if repl:
new_r = r.replace_inputs(new_inputs, input_indices=new_indices)
replacement[id(r)] = new_r
new_res.append(new_r)
else:
new_res.append(r)
# check the graph is consistent
known = {}
for r in new_res:
known[id(r)] = r
if isinstance(r, (Cst, Input)):
continue
for ind, i in enumerate(r.inputs):
if i is None:
# optional input
continue
if id(i) in replacement_cst:
# constant to replace
continue
if id(i) not in known:
raise RuntimeError(
f"An input {ind} ({id(i)}, type={type(i)}) from "
f"{id(r)}-{r} is not known, it is not produced by a "
f"previous var (scheduled for replacement: "
f"{id(i) in replacement}). This also happens if "
f"a constant is not wrapped by 'cst(.)'."
)
return new_res
@property
def is_function(self):
"""
Tells if this variable encapsulates a function.
"""
return self.onnx_op is not None and self.onnx_op[0] is None
def to_onnx(
self,
target_opsets: Optional[Dict[str, int]] = None,
as_function: bool = False,
name: Optional[str] = None,
domain: Optional[str] = None,
attributes: Optional[List[str]] = None,
constraints: Optional[Dict[Any, TensorType]] = None,
ir_version: Optional[int] = None,
) -> Union[ModelProto, FunctionProto, List[Any]]:
"""
Converts the recursive graph to ONNX.
:param target_opsets: dictionary `{domain: version}`; if None,
it is replaced by `DEFAULT_OPSETS`
:param as_function: conversion to :class:`onnx.FunctionProto`
or :class:`onnx.ModelProto`
:param name: function name if *as_function* is True
:param domain: function domain if *as_function* is True
:param attributes: function attributes if any
:param constraints: specifies a precise type for the type
constraints when a function allows more than one type,
this works if there is only one variable to be converted
:return: ModelProto, FunctionProto, or a list `[functions, proto]`
if *as_function* is True and local functions were created
"""
from .npx_graph_builder import _GraphBuilder
# Var.to_onnx
if target_opsets is None:
target_opsets = DEFAULT_OPSETS
vs = self._get_vars()
g = _GraphBuilder(
target_opsets,
as_function=as_function,
name=name,
domain=domain,
attributes=attributes,
constraints=constraints,
ir_version=ir_version,
)
for var in vs:
g.append(var)
onx = g.to_onnx()
if as_function and len(g.functions_) > 0:
return [g.functions_, onx]
return onx
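# Illustrative sketch (assumption, the exact import of ``Input`` may differ):
# building a small graph and converting it to ONNX.
#
#   from onnx_array_api.npx.npx_var import Input
#   x = Input("X")
#   model = (x + x).to_onnx()   # ModelProto with a single Add node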
# Operators
def __iter__(self):
"""
The :epkg:`Array API` does not define this function (2022/12).
This method raises an exception with a better error message.
"""
raise ArrayApiError(
f"Iterators are not implemented in the generic case. "
f"Every function using them cannot be converted into ONNX "
f"(Var - {type(self)})."
)
def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var":
var = Var.get_cst_var()[1]
if isinstance(ov, (int, float, bool, np.ndarray, Cst)):
return var(self.self_var, var(ov, self.self_var, op="CastLike"), op=op_name)
return var(self.self_var, ov, op=op_name, **kwargs)
def _binary_op_right(self, ov: "Var", op_name: str, **kwargs) -> "Var":
var = Var.get_cst_var()[1]
if isinstance(ov, (int, float, bool, np.ndarray, Cst)):
return var(var(ov, self.self_var, op="CastLike"), self.self_var, op=op_name)
return var(ov, self.self_var, op=op_name, **kwargs)
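# Example of the casting rule above (sketch): a Python scalar or numpy array on
# either side is first aligned to the Var's type with ``CastLike``, so
# ``x + 1`` becomes ``Add(x, CastLike(1, x))`` and ``1 - x`` becomes
# ``Sub(CastLike(1, x), x)``.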
def __neg__(self) -> "Var":
"""
Automatically adds operator `Neg` to the graph.
It does not cast automatically.
"""
var = Var.get_cst_var()[1]
return var(self.self_var, op="Neg")
def __invert__(self) -> "Var":
"""
Automatically adds operator `BitwiseNot` to the graph.
It does not cast automatically.
"""
var = Var.get_cst_var()[1]
return var(self.self_var, op="BitwiseNot")
def __add__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Add` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Add")
def __radd__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Add` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "Add")
def __sub__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Sub` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Sub")
def __rsub__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Sub` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "Sub")
def __mul__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Mul` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Mul")
def __rmul__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Mul` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "Mul")
def __matmul__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `MatMul` to the graph.
It does not cast automatically.
`__rmatmul__` would not be called because a numpy array
overrides `__matmul__` on its side.
"""
return self._binary_op(ov, "MatMul")
def __truediv__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Div` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Div")
def __rtruediv__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Div` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "Div")
def __mod__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Mod` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Mod")
def __rmod__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Mod` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "Mod")
def __pow__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Pow` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Pow")
def __rpow__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Pow` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "Pow")
def __lt__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Less` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Less")
def __le__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `LessOrEqual` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "LessOrEqual")
def __gt__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Greater` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Greater")
def __ge__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `GreaterOrEqual` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "GreaterOrEqual")
def __eq__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `Equal` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "Equal")
def __ne__(self, ov: "Var") -> "Var":
"""
Automatically adds operators `Equal` followed by `Not` to the graph.
It does not cast automatically.
"""
var = Var.get_cst_var()[1]
return var(self._binary_op(ov, "Equal"), op="Not")
def __lshift__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitShift` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "BitShift", direction="LEFT")
def __rshift__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitShift` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "BitShift", direction="RIGHT")
def __and__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitwiseAnd` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "BitwiseAnd")
def __rand__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitwiseAnd` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "BitwiseAnd")
def __or__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitwiseOr` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "BitwiseOr")
def __ror__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitwiseOr` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "BitwiseOr")
def __xor__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitwiseXor` to the graph.
It does not cast automatically.
"""
return self._binary_op(ov, "BitwiseXor")
def __rxor__(self, ov: "Var") -> "Var":
"""
Automatically adds operator `BitwiseXor` to the graph.
It does not cast automatically.
"""
return self._binary_op_right(ov, "BitwiseXor")
@property
def T(self) -> "Var":
"Transpose."
var = Var.get_cst_var()[1]
return var(self.self_var, op="Transpose", perm=[1, 0])
def astype(self, dtype) -> "Var":
"Cast"
var = Var.get_cst_var()[1]
if isinstance(dtype, Var):
return var(self.self_var, dtype, op="CastLike")
if not isinstance(dtype, DType):
raise TypeError(f"dtype cannot be {type(dtype)}.")
return var(self.self_var, op="Cast", to=dtype)
@property
def shape(self) -> "Var":
"Shape"
var = Var.get_cst_var()[1]
return var(self.self_var, op="Shape")
def reshape(self, shape: "Var") -> "Var":
"Reshape"
cst, var = Var.get_cst_var()
if isinstance(shape, (tuple, list)):
shape = np.array(shape, dtype=np.int64)
else:
shape = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape")
return var(self.self_var, shape, op="Reshape")
def reduce_function(
self,
reduce_op,
axis: OptTensorType[ElemType.int64, "I"] = None,
keepdims: ParType[int] = 0,
) -> "Var":
"See :func:`numpy.sum` or any other reduce function."
var = Var.get_cst_var()[1]
if axis is None:
return var(self.self_var, op=reduce_op, keepdims=keepdims)
if isinstance(axis, int):
axis = [axis]
if isinstance(axis, (tuple, list)):
cst = Var.get_cst_var()[0]
axis = cst(np.array(axis, dtype=np.int64))
return var(self.self_var, axis, op=reduce_op, keepdims=keepdims)
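# Example (sketch): ``x.sum(axis=1, keepdims=0)`` (defined below) goes through
# this helper and becomes ``ReduceSum(x, axes=[1], keepdims=0)``; with
# ``axis=None`` the axes input is omitted and the reduction applies to every
# dimension.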
def sum(
self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
) -> "Var":
"See :func:`numpy.sum`."
return self.reduce_function("ReduceSum", axis=axis, keepdims=keepdims)
def mean(
self, axis: OptParType[TupleType[int]] = None, keepdims: ParType[int] = 0
) -> "Var":
"See :func:`numpy.mean`."
return self.reduce_function("ReduceMean", axis=axis, keepdims=keepdims)
def min(
self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
) -> "Var":
"See :func:`numpy.min`."
return self.reduce_function("ReduceMin", axis=axis, keepdims=keepdims)
def max(
self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
) -> "Var":
"See :func:`numpy.max`."
return self.reduce_function("ReduceMax", axis=axis, keepdims=keepdims)
def prod(
self, axis: TensorType[ElemType.int64, "I"] = None, keepdims: ParType[int] = 0
) -> "Var":
"See :func:`numpy.prod`."
return self.reduce_function("ReduceProd", axis=axis, keepdims=keepdims)
def copy(self) -> "Var":
"""
Returns a copy of self (use of Identity node).
"""
var = Var.get_cst_var()[1]
return var(self.self_var, op="Identity")
def flatten(self) -> "Var":
"""
Flattens a tensor into a 1-D tensor (see :meth:`numpy.ndarray.flatten`).
:return: :class:`Var <onnx_array_api.npx.npx_var.Var>`
"""
cst, var = Var.get_cst_var()
return var(
var(self.self_var, op="Flatten", axis=0),
cst(np.array([0], dtype=np.int64)),
op="Squeeze",
)
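# Example (sketch): for a tensor of shape ``(2, 3)``, ``Flatten(axis=0)``
# produces shape ``(1, 6)`` and the following ``Squeeze`` on axis 0 yields the
# 1-D result of shape ``(6,)``.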
def get(self, index: int) -> "Var":
"""
If an operator or a function returns more than one output,
this takes only one.
:param index: index of the output to select
:return: Var
"""
if index < 0 or index >= self.n_var_outputs:
raise ValueError(
f"index={index} must be positive and < {self.n_var_outputs} "
f"for var={self!r}."
)
return Var(self.self_var, input_indices=[index], op="Identity")
def __getitem__(self, index: Any) -> "Var":
"""
Deals with multiple scenarios.
* *index* is an integer and the object produces multiple
outputs and this returns one of them (**scenario 0**)
* *index* is an integer or a slice, a tuple of integers and slices,
example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**)
* *index* is an *ONNX* object (more precisely an instance of
:class:`Var <onnx_array_api.npx.npx_var.Var>`),
then the method assumes it is an array of
booleans used to select a subset of the tensor along the first axis,
example: `mat[mat == 0]` (**scenario 2**)
"""
cst, var = Var.get_cst_var()
if self.n_var_outputs != 1:
# Multioutput
if not isinstance(index, int):
raise TypeError(
f"Only indices are allowed when selecting an output, "
f"not {type(index)})."
)
return self.get(index)
if isinstance(index, Var):
# scenario 2
# we rely on the annotation if it exists
if index.annotation is None:
dtype_bool = True
elif issubclass(index.annotation, TensorType):
if index.annotation.supports_dtype(
DType(TensorProto.INT64)
) or index.annotation.supports_dtype(DType(TensorProto.INT32)):
dtype_bool = False
elif index.annotation.supports_dtype(DType(TensorProto.BOOL)):
dtype_bool = True
else:
raise TypeError(
f"Unexpected dtype for annotation={index.annotation!r} "
f"for index={index!r}."
)
else:
raise TypeError(
f"Unexpected annotation={index.annotation!r} "
f"for index={index!r}."
)
if dtype_bool:
# TODO: fix this when index is an integer and the annotation is unknown;
# it needs to support subgraphs and tests
new_shape = cst(np.array([-1], dtype=np.int64))
new_self = self.reshape(new_shape)
new_index = index.reshape(new_shape)
return var(new_self, new_index, op="Compress")
# dtype is int
return var(self, index, axis=0, op="Gather")
if isinstance(index, int):
# Use Gather instead.
return var(self, cst(np.array(index, dtype=np.int64)), axis=0, op="Gather")
if not isinstance(index, tuple):
index = (index,)
elif not index:
# The array contains a scalar and it needs to be returned.
return var(self, op="Identity")
# only one integer?
ni = None
ax = None
for i, a in enumerate(index):
if isinstance(a, int):
if ni is None:
ni = i
ax = a
else:
ax = None
ni = None
break
if (
isinstance(a, slice)
and a.start is None
and a.stop is None
and a.step is None
):
continue
ax = None
ni = None
break
if ni is not None and ax is not None:
# Use Gather instead.
return var(self, cst(np.array(ni, dtype=np.int64)), axis=ax, op="Gather")
# scenario 1
starts = []
ends = []
axes = []
steps = []
axis_squeeze = []
needs_shape = []
for i, ind in enumerate(index):
if isinstance(ind, int):
starts.append(ind)
ends.append(ind + 1)
axes.append(i)
steps.append(1)
axis_squeeze.append(i)
continue
if isinstance(ind, slice):
if ind.start is None and ind.stop is None and ind.step is None:
continue
start = 0 if ind.start is None else ind.start
end = (None, i) if ind.stop is None else ind.stop
step = 1 if ind.step is None else ind.step
starts.append(start)
ends.append(end)
axes.append(i)
steps.append(step)
if isinstance(end, tuple):
needs_shape.append(len(ends) - 1)
elif isinstance(end, Var):
needs_shape.append(end)
continue
raise NotImplementedError(f"Not implemented for type {type(ind)!r}.")
if max(steps) == min(steps) == 1:
steps = None
else:
steps = np.array(steps, dtype=np.int64)
starts = np.array(starts, dtype=np.int64)
axes = np.array(axes, dtype=np.int64)
if needs_shape:
shape = self.shape
conc = []
for e in ends:
if isinstance(e, tuple):
conc.append(
var(shape, cst(np.array([e[1]], np.int64)), op="Gather")
)
elif isinstance(e, Var):
conc.append(e.reshape(np.array([-1], dtype=np.int64)))
else:
conc.append(np.array([e], dtype=np.int64))
if len(conc) > 1:
conc_cst = [v if isinstance(v, Var) else cst(v) for v in conc]
ends = var(*conc_cst, op="Concat", axis=0)
else:
ends = conc[0]
else:
ends = np.array(ends, dtype=np.int64)
sliced_args = [starts, ends, axes]
if steps is not None:
sliced_args.append(steps)
sliced_args_cst = [v if isinstance(v, Var) else cst(v) for v in sliced_args]
sliced = var(self.self_var, *sliced_args_cst, op="Slice")
if axis_squeeze:
return var(
sliced,
cst(np.array(axis_squeeze, dtype=np.int64)),
op="Squeeze",
)
return sliced
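# Examples for the scenarios above (sketch):
#   v[1] when v has several outputs -> Identity with input_indices=[1] (scenario 0)
#   v[2]    -> Gather(v, 2, axis=0)                                    (scenario 1)
#   v[1:3]  -> Slice(starts=[1], ends=[3], axes=[0])                   (scenario 1)
#   v[mask] with a boolean Var      -> Compress on flattened tensors   (scenario 2)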
def __setitem__(self, index, values):
new_op = self.set[index](values)
self.current_var_ = new_op
self.input_indices = None
class Cst(Var):
"""
Defines a constant.
"""
def __init__(self, cst: Any):
if isinstance(cst, np.ndarray):
Var.__init__(self, cst, op="Identity")
elif isinstance(cst, bool):
Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity")
elif isinstance(cst, int):
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
elif isinstance(cst, float):
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
elif isinstance(cst, list):
if all(map(lambda t: isinstance(t, bool), cst)):
Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity")
elif all(map(lambda t: isinstance(t, (int, bool)), cst)):
Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity")
elif all(map(lambda t: isinstance(t, (float, int, bool)), cst)):
Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity")
else:
raise ValueError(
f"Unable to convert cst (type={type(cst)}), " f"value={cst}."
)
else:
raise NotImplementedError(
f"Constant of type {type(cst)} are not implemented yet. "
f"You should not use 'float32(x)' but 'array(x, dtype=float32)'."
)
self._prefix = "cst"