Source code for onnx_array_api.light_api.var

from typing import Any, Dict, List, Optional, Union
import numpy as np
from onnx import TensorProto
from .annotations import (
    elem_type_int,
    make_shape,
    ELEMENT_TYPE,
    ELEMENT_TYPE_NAME,
    GRAPH_PROTO,
    SHAPE_TYPE,
    VAR_CONSTANT_TYPE,
)
from .model import OnnxGraph
from ._op_var import OpsVar
from ._op_vars import OpsVars


class BaseVar:
    """
    Represents an input, an initializer, a node, an output,
    multiple variables.

    :param parent: the graph containing the Variable
    """

    def __init__(
        self,
        parent: OnnxGraph,
    ):
        self.parent = parent

    def make_node(
        self,
        op_type: str,
        *inputs: List[VAR_CONSTANT_TYPE],
        domain: str = "",
        n_outputs: int = 1,
        output_names: Optional[List[str]] = None,
        **kwargs: Dict[str, Any],
    ) -> Union["Var", "Vars"]:
        """
        Creates a node in the parent graph.

        Only ``inputs`` are given to the node: this variable is not added
        implicitly, callers pass it explicitly when it must be an input.

        :param op_type: operator type
        :param inputs: node inputs
        :param domain: domain
        :param n_outputs: number of outputs
        :param output_names: output names, if not specified,
            outputs are given unique names
        :param kwargs: node attributes
        :return: instance of :class:`onnx_array_api.light_api.Var` if the
            node has a single output, :class:`onnx_array_api.light_api.Vars`
            otherwise
        """
        node_proto = self.parent.make_node(
            op_type,
            *inputs,
            domain=domain,
            n_outputs=n_outputs,
            output_names=output_names,
            **kwargs,
        )
        names = node_proto.output
        if len(names) == 1:
            return Var(self.parent, names[0])
        # Vars takes the graph as its first positional argument; without it
        # the first output Var would be mistaken for the parent graph.
        return Vars(self.parent, *[Var(self.parent, name) for name in names])

    def vin(
        self,
        name: str,
        elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
        shape: Optional[SHAPE_TYPE] = None,
    ) -> "Var":
        """
        Declares a new input to the graph.

        :param name: input name
        :param elem_type: element_type
        :param shape: shape
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        return self.parent.vin(name, elem_type=elem_type, shape=shape)

    def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
        """
        Adds an initializer.

        :param value: constant tensor
        :param name: initializer name, None lets the graph choose one
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        c = self.parent.make_constant(value, name=name)
        return Var(self.parent, c.name, elem_type=c.data_type, shape=tuple(c.dims))

    def vout(
        self,
        elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
        shape: Optional[SHAPE_TYPE] = None,
    ) -> "Var":
        """
        Declares this variable (by its name) as an output of the graph.

        It relies on ``self.name``, which only :class:`Var` defines.

        :param elem_type: element_type
        :param shape: shape
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
        # make_output returns the output description; the new Var needs its
        # name, not the description itself.
        return Var(
            self.parent,
            output.name,
            elem_type=output.type.tensor_type.elem_type,
            shape=make_shape(output.type.tensor_type.shape),
        )

    def v(self, name: str) -> "Var":
        """
        Retrieves a variable of the parent graph by its name.

        :param name: name of the variable
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        return self.parent.get_var(name)

    def bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
        """
        Creates a set of variables as an instance of
        :class:`onnx_array_api.light_api.Vars`.
        Only ``*vars`` are included, ``self`` is not added.
        """
        return Vars(self.parent, *vars)

    def left_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
        """
        Creates a set of variables as an instance of
        :class:`onnx_array_api.light_api.Vars`.
        `*vars` is added to the left, `self` is added to the right.
        """
        vs = [*vars, self]
        return Vars(self.parent, *vs)

    def right_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
        """
        Creates a set of variables as an instance of
        :class:`onnx_array_api.light_api.Vars`.
        `*vars` is added to the right, `self` is added to the left.
        """
        vs = [self, *vars]
        return Vars(self.parent, *vs)

    def to_onnx(self) -> GRAPH_PROTO:
        "Creates the onnx graph."
        return self.parent.to_onnx()
class Var(BaseVar, OpsVar):
    """
    Represents an input, an initializer, a node, an output.

    :param parent: graph the variable belongs to
    :param name: input name
    :param elem_type: element_type
    :param shape: shape

    Because ``__eq__`` is overridden to build an ``Equal`` node, Python
    sets ``__hash__`` to None: instances cannot be used as dict keys or
    set members.
    """

    def __init__(
        self,
        parent: OnnxGraph,
        name: str,
        elem_type: Optional[ELEMENT_TYPE] = 1,
        shape: Optional[SHAPE_TYPE] = None,
    ):
        BaseVar.__init__(self, parent)
        # Name given at creation; the ``name`` property resolves renames.
        self.name_ = name
        self.elem_type = elem_type
        self.shape = shape

    @property
    def name(self):
        "Returns the name of the variable or the new name if it was renamed."
        return self.parent.true_name(self.name_)

    def __str__(self) -> str:
        """
        Returns ``name``, ``name:TYPE`` or ``name:TYPE:[d1,d2,...]``
        depending on whether the element type and the shape are known.
        """
        s = f"{self.name}"
        if self.elem_type is None:
            return s
        s = f"{s}:{ELEMENT_TYPE_NAME[self.elem_type]}"
        if self.shape is None:
            return s
        # Dimensions are separated so that (2, 3) and (23,) differ.
        return f"{s}:[{','.join(map(str, self.shape))}]"

    def rename(self, new_name: str) -> "Var":
        "Renames a variable in the parent graph and returns itself."
        self.parent.rename(self.name, new_name)
        return self

    def to(self, to: ELEMENT_TYPE) -> "Var":
        "Casts a tensor into another element type."
        return self.Cast(to=elem_type_int(to))

    def astype(self, to: ELEMENT_TYPE) -> "Var":
        "Casts a tensor into another element type (same as :meth:`to`)."
        return self.to(to)

    def reshape(self, new_shape: VAR_CONSTANT_TYPE) -> "Var":
        """
        Reshapes a variable.

        :param new_shape: a tuple or a list of integers, converted into an
            int64 initializer, or anything :meth:`bring` accepts
            (a variable for example)
        :return: output of a ``Reshape`` node
        """
        if isinstance(new_shape, (tuple, list)):
            new_shape = self.cst(np.array(new_shape, dtype=np.int64))
        return self.bring(self, new_shape).Reshape()

    def __add__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Add()

    def __eq__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Equal()

    def __float__(self) -> "Var":
        """
        Adds a ``Cast`` node to FLOAT.

        ``float(var)`` itself raises TypeError because Python requires
        ``__float__`` to return a real float; prefer :meth:`astype`.
        """
        return self.Cast(to=TensorProto.FLOAT)

    def __gt__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Greater()

    def __ge__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).GreaterOrEqual()

    def __int__(self) -> "Var":
        """
        Adds a ``Cast`` node to INT64.

        ``int(var)`` itself raises TypeError because Python requires
        ``__int__`` to return a real int; prefer :meth:`astype`.
        """
        return self.Cast(to=TensorProto.INT64)

    def __lt__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Less()

    def __le__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).LessOrEqual()

    def __matmul__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).MatMul()

    def __mod__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Mod()

    def __mul__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Mul()

    def __ne__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Equal().Not()

    def __neg__(self) -> "Var":
        "Unary minus: adds a ``Neg`` node (``-var``)."
        return self.Neg()

    def __pow__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Pow()

    def __sub__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Sub()

    def __truediv__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Div()
class Vars(BaseVar, OpsVars):
    """
    Groups several :class:`Var` sharing the same graph.

    :param parent: graph the variables belong to
    :param vars: variables or variable names; names are looked up
        in the parent graph
    """

    def __init__(self, parent, *vars: List[Union[str, Var]]):
        BaseVar.__init__(self, parent)
        # Strings are resolved through the graph, other items are kept as given.
        self.vars_ = [
            self.parent.get_var(item) if isinstance(item, str) else item
            for item in vars
        ]

    def __len__(self):
        "Returns the number of variables."
        return len(self.vars_)

    def _check_nin(self, n_inputs):
        """
        Ensures the group holds exactly ``n_inputs`` variables.

        :return: self, to allow chaining
        :raises RuntimeError: if the count differs
        """
        count = len(self)
        if count == n_inputs:
            return self
        raise RuntimeError(f"Expecting {n_inputs} inputs not {count}.")

    def rename(self, new_name: str) -> "Var":
        """
        Renames variables. Not supported for a group.

        :raises NotImplementedError: always
        """
        raise NotImplementedError("Not yet implemented.")