Source code for onnx_array_api.light_api.var

import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import TensorProto
from onnx.defs import get_schema
from ..annotations import (
    elem_type_int,
    make_shape,
    ELEMENT_TYPE,
    ELEMENT_TYPE_NAME,
    GRAPH_PROTO,
    SHAPE_TYPE,
    VAR_CONSTANT_TYPE,
)
from .model import OnnxGraph
from ._op_var import OpsVar
from ._op_vars import OpsVars


class SubDomain:
    """
    Groups the operators of one domain, or of one dotted segment of a
    domain name (when the name contains '.'), and remembers the object
    through which the domain was reached.

    :param var: the :class:`BaseVar` instance the domain is accessed from
    :raises TypeError: if *var* is not a :class:`BaseVar`
    """

    def __init__(self, var: "BaseVar"):
        if isinstance(var, BaseVar):
            self.parent = var
            return
        raise TypeError(f"Unexpected type {type(var)}.")
def _getclassattr_(self, name): if not hasattr(self.__class__, name): raise TypeError( f"Unable to find {name!r} in class {self.__class__.__name__!r}, " f"available {dir(self.__class__)}." ) return getattr(self.__class__, name)
class BaseVar:
    """
    Represents an input, an initializer, a node, an output,
    multiple variables.

    :param parent: the graph containing the Variable
    """

    def __new__(cls, *args, **kwargs):
        """
        If called for the first instantiation of a BaseVar, it processes all
        methods declared with decorator :func:`onnx_array_api.light_api.domain`
        so that they can be called with a syntax `v.<domain>.<op_type>`.

        Such methods are recognized by a ``__name__`` of the form
        ``"[<domain>]<op_type>"``. For every domain (and every dotted
        segment of a domain name) a :class:`SubDomain` subclass is created,
        stored as ``_<segment>`` and exposed through a property named
        ``<segment>`` which returns an instance of it.
        """
        res = super().__new__(cls)
        # NOTE(review): __init__ runs here and then again when Python
        # completes the constructor call (__new__ returns an instance of cls).
        # Kept as-is so invalid arguments are rejected before the class
        # is modified below.
        res.__init__(*args, **kwargs)
        if getattr(cls, "__incomplete", True):
            for k in dir(cls):
                att = getattr(cls, k, None)
                if not att:
                    continue
                name = getattr(att, "__name__", None)
                if not name or name[0] != "[":
                    continue
                # A function with a domain name: "[domain]op_type".
                if not inspect.isfunction(att):
                    raise RuntimeError(f"{cls.__name__}.{k} is not a function.")
                domain, op_type = name[1:].split("]")
                if "." in domain:
                    spl = domain.split(".", maxsplit=1)
                    dname = f"_{spl[0]}"
                    if not hasattr(cls, dname):
                        d = type(
                            f"{cls.__name__}{dname}", (SubDomain,), {"name": dname[1:]}
                        )
                        setattr(cls, dname, d)
                        setattr(
                            cls,
                            spl[0],
                            property(
                                lambda self, _name_=dname: _getclassattr_(self, _name_)(
                                    self
                                )
                            ),
                        )
                    else:
                        d = getattr(cls, dname)
                    # Walk (and create when missing) one nested SubDomain
                    # per remaining dotted segment.
                    suffix = spl[0]
                    for p in spl[1].split("."):
                        dname = f"_{p}"
                        suffix += dname
                        if not hasattr(d, dname):
                            sd = type(
                                f"{cls.__name__}_{suffix}",
                                (SubDomain,),
                                {"name": suffix},
                            )
                            setattr(d, dname, sd)
                            # Nested properties are accessed on a SubDomain
                            # instance, hence `self.parent` (the BaseVar).
                            setattr(
                                d,
                                p,
                                property(
                                    lambda self, _name_=dname: _getclassattr_(
                                        self, _name_
                                    )(self.parent)
                                ),
                            )
                            d = sd
                        else:
                            d = getattr(d, dname)
                elif not hasattr(cls, domain):
                    dname = f"_{domain}"
                    d = type(f"{cls.__name__}{dname}", (SubDomain,), {"name": domain})
                    setattr(cls, dname, d)
                    setattr(
                        cls,
                        domain,
                        property(
                            lambda self, _name_=dname: _getclassattr_(self, _name_)(
                                self
                            )
                        ),
                    )
                else:
                    # The domain was already registered by an earlier function:
                    # fetch its SubDomain class rather than reusing the `d` left
                    # over from the previous iteration (which may belong to a
                    # different domain, or be unbound on the first match).
                    d = getattr(cls, f"_{domain}", None)
                    if not (isinstance(d, type) and issubclass(d, SubDomain)):
                        raise RuntimeError(
                            f"Domain {domain!r} conflicts with attribute "
                            f"{cls.__name__}.{domain}."
                        )
                setattr(d, op_type, att)
            setattr(cls, "__incomplete", False)
        return res

    def __init__(
        self,
        parent: OnnxGraph,
    ):
        if not isinstance(parent, OnnxGraph):
            raise RuntimeError(f"Unexpected parent type {type(parent)}.")
        self.parent = parent

    def make_node(
        self,
        op_type: str,
        *inputs: List[VAR_CONSTANT_TYPE],
        domain: str = "",
        n_outputs: int = 1,
        output_names: Optional[List[str]] = None,
        **kwargs: Dict[str, Any],
    ) -> Union["Var", "Vars"]:
        """
        Creates a node with this Var as the first input.

        :param op_type: operator type
        :param inputs: others inputs
        :param domain: domain
        :param n_outputs: number of outputs
        :param output_names: output names, if not specified, outputs are given
            unique names
        :param kwargs: node attributes
        :return: instance of :class:`onnx_array_api.light_api.Var` or
            :class:`onnx_array_api.light_api.Vars`
        :raises RuntimeError: if the number of inputs or outputs does not match
            the operator schema (standard and ``ai.onnx.ml`` domains only), or
            if the created node does not have *n_outputs* outputs
        """
        if domain in ("", "ai.onnx.ml"):
            if self.parent.opset is None:
                schema = get_schema(op_type, domain)
            else:
                schema = get_schema(op_type, self.parent.opset, domain)
            if n_outputs < schema.min_output or n_outputs > schema.max_output:
                raise RuntimeError(
                    f"Unexpected number of outputs ({n_outputs}) "
                    f"for node type {op_type!r}, domain={domain!r}, "
                    f"version={self.parent.opset}, it should be in "
                    f"[{schema.min_output}, {schema.max_output}]."
                )
            n_inputs = len(inputs)
            if n_inputs < schema.min_input or n_inputs > schema.max_input:
                raise RuntimeError(
                    f"Unexpected number of inputs ({n_inputs}) "
                    f"for node type {op_type!r}, domain={domain!r}, "
                    f"version={self.parent.opset}, it should be in "
                    f"[{schema.min_input}, {schema.max_input}]."
                )

        node_proto = self.parent.make_node(
            op_type,
            *inputs,
            domain=domain,
            n_outputs=n_outputs,
            output_names=output_names,
            **kwargs,
        )
        names = node_proto.output
        # Compare against the requested count (the previous check compared
        # the output list with itself and could never fail).
        if n_outputs is not None and len(names) != n_outputs:
            raise RuntimeError(
                f"Expects {n_outputs} outputs but output names are {names}."
            )

        if len(names) == 1:
            return Var(self.parent, names[0])
        return Vars(self.parent, *[Var(self.parent, v) for v in names])

    def vin(
        self,
        name: str,
        elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
        shape: Optional[SHAPE_TYPE] = None,
    ) -> "Var":
        """
        Declares a new input to the graph.

        :param name: input name
        :param elem_type: element_type
        :param shape: shape
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        return self.parent.vin(name, elem_type=elem_type, shape=shape)

    def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var":
        """
        Adds an initializer

        :param value: constant tensor
        :param name: input name
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        c = self.parent.make_constant(value, name=name)
        return Var(self.parent, c.name, elem_type=c.data_type, shape=tuple(c.dims))

    def v(self, name: str) -> "Var":
        """
        Retrieves another variable than this one.

        :param name: name of the variable
        :return: instance of :class:`onnx_array_api.light_api.Var`
        """
        return self.parent.get_var(name)

    def bring(self, *vars: List[Union[str, "Var"]]) -> Union["Var", "Vars"]:
        """
        Creates a set of variable as an instance of
        :class:`onnx_array_api.light_api.Vars`.

        With a single argument, a :class:`onnx_array_api.light_api.Var` is
        returned instead: the argument itself if it already is a Var,
        otherwise a new Var named after it.
        """
        if len(vars) == 1:
            single = vars[0]
            # Wrapping an existing Var would store a Var object as the name.
            if isinstance(single, Var):
                return single
            return Var(self.parent, single)
        return Vars(self.parent, *vars)

    def vout(self, **kwargs: Dict[str, Any]) -> Union["Var", "Vars"]:
        """
        This method needs to be overwritten for Var and Vars depending
        on the number of variable to declare as outputs.
        """
        raise RuntimeError(f"The method was not overwritten in class {type(self)}.")

    def left_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
        """
        Creates a set of variables as an instance of
        :class:`onnx_array_api.light_api.Vars`.
        `*vars` is added to the left, `self` is added to the right.
        """
        vs = [*vars, self]
        return Vars(self.parent, *vs)

    def right_bring(self, *vars: List[Union[str, "Var"]]) -> "Vars":
        """
        Creates a set of variables as an instance of
        :class:`onnx_array_api.light_api.Vars`.
        `*vars` is added to the right, `self` is added to the left.
        """
        vs = [self, *vars]
        return Vars(self.parent, *vs)

    def to_onnx(self) -> GRAPH_PROTO:
        "Creates the onnx graph."
        return self.parent.to_onnx()
class Var(BaseVar, OpsVar):
    """
    Represents an input, an initializer, a node, an output.

    :param parent: graph the variable belongs to
    :param name: input name
    :param elem_type: element_type
    :param shape: shape

    Note that ``==`` and ``!=`` build Equal nodes; because ``__eq__`` is
    overridden without ``__hash__``, instances are not hashable.
    """

    def __init__(
        self,
        parent: OnnxGraph,
        name: str,
        elem_type: Optional[ELEMENT_TYPE] = 1,
        shape: Optional[SHAPE_TYPE] = None,
    ):
        BaseVar.__init__(self, parent)
        self.name_ = name
        self.elem_type = elem_type
        self.shape = shape

    @property
    def name(self):
        "Returns the name of the variable or the new name if it was renamed."
        return self.parent.true_name(self.name_)

    def __str__(self) -> str:
        "Returns ``name``, ``name:type`` or ``name:type:[d1, d2, ...]``."
        s = f"{self.name}"
        if self.elem_type is None:
            return s
        s = f"{s}:{ELEMENT_TYPE_NAME[self.elem_type]}"
        if self.shape is None:
            return s
        # Separate dimensions, otherwise (2, 3) would be rendered as [23].
        return f"{s}:[{', '.join(map(str, self.shape))}]"

    def vout(
        self,
        elem_type: ELEMENT_TYPE = TensorProto.FLOAT,
        shape: Optional[SHAPE_TYPE] = None,
    ) -> "Var":
        """
        Declares a new output to the graph.

        :param elem_type: element_type
        :param shape: shape
        :return: instance of :class:`onnx_array_api.light_api.Var`

        If the checker fails, try `shape=[]`.
        """
        output = self.parent.make_output(self.name, elem_type=elem_type, shape=shape)
        # The Var must be named after the output, not hold the proto itself.
        return Var(
            self.parent,
            output.name,
            elem_type=output.type.tensor_type.elem_type,
            shape=make_shape(output.type.tensor_type.shape),
        )

    def rename(self, new_name: str) -> "Var":
        "Renames a variable."
        self.parent.rename(self.name, new_name)
        return self

    def to(self, to: ELEMENT_TYPE) -> "Var":
        "Casts a tensor into another element type."
        return self.Cast(to=elem_type_int(to))

    def astype(self, to: ELEMENT_TYPE) -> "Var":
        "Casts a tensor into another element type."
        return self.Cast(to=elem_type_int(to))

    def reshape(self, new_shape: VAR_CONSTANT_TYPE) -> "Var":
        """
        Reshapes a variable.

        :param new_shape: a tuple or list of integers (turned into an int64
            initializer) or anything accepted as a second input of Reshape
        """
        if isinstance(new_shape, (tuple, list)):
            cst = self.cst(np.array(new_shape, dtype=np.int64))
            return self.bring(self, cst).Reshape()
        return self.bring(self, new_shape).Reshape()

    def __add__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Add()

    def __eq__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Equal()

    def __float__(self, var: Optional[VAR_CONSTANT_TYPE] = None) -> "Var":
        """
        Casts this variable into float.

        Cast has a single input, so *var* is ignored (kept for backward
        compatibility). Python's builtin ``float()`` still rejects the
        returned Var; call this method explicitly.
        """
        return self.Cast(to=TensorProto.FLOAT)

    def __gt__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Greater()

    def __ge__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).GreaterOrEqual()

    def __int__(self, var: Optional[VAR_CONSTANT_TYPE] = None) -> "Var":
        """
        Casts this variable into int64.

        Cast has a single input, so *var* is ignored (kept for backward
        compatibility). Python's builtin ``int()`` still rejects the
        returned Var; call this method explicitly.
        """
        return self.Cast(to=TensorProto.INT64)

    def __lt__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Less()

    def __le__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).LessOrEqual()

    def __matmul__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).MatMul()

    def __mod__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Mod()

    def __mul__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Mul()

    def __ne__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Equal().Not()

    def __neg__(self, var: Optional[VAR_CONSTANT_TYPE] = None) -> "Var":
        """
        Unary minus. Python calls ``-v`` without arguments, so *var* is
        optional and ignored (kept for backward compatibility).
        """
        return self.Neg()

    def __pow__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Pow()

    def __sub__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Sub()

    def __truediv__(self, var: VAR_CONSTANT_TYPE) -> "Var":
        "Intuitive."
        return self.bring(self, var).Div()
class Vars(BaseVar, OpsVars):
    """
    Represents multiple Var.

    :param parent: graph the variable belongs to
    :param vars: list of names or variables; names are resolved with
        ``parent.get_var``, anything else is stored as given
    """

    def __init__(self, parent, *vars: List[Union[str, Var]]):
        BaseVar.__init__(self, parent)
        self.vars_ = [
            self.parent.get_var(v) if isinstance(v, str) else v for v in vars
        ]

    def __len__(self):
        "Returns the number of variables."
        return len(self.vars_)

    def _check_nin(self, n_inputs):
        # Raises if this set does not hold exactly `n_inputs` variables.
        if len(self) != n_inputs:
            raise RuntimeError(f"Expecting {n_inputs} inputs not {len(self)}.")
        return self

    def rename(self, *new_names: List[str]) -> "Vars":
        """
        Renames variables, one new name per variable, in order.

        :param new_names: the new names
        :return: a new :class:`onnx_array_api.light_api.Vars` built from the
            new names
        :raises ValueError: if the number of names differs from the number
            of variables
        """
        if len(new_names) != len(self):
            raise ValueError(
                f"Vars has {len(self)} elements but the method received {len(new_names)} names."
            )
        for var, name in zip(self.vars_, new_names):
            var.rename(name)
        return Vars(self.parent, *new_names)

    def vout(
        self,
        *elem_type_shape: List[
            Union[ELEMENT_TYPE, Tuple[ELEMENT_TYPE, Optional[SHAPE_TYPE]]]
        ],
    ) -> "Vars":
        """
        Declares a new output to the graph.

        :param elem_type_shape: one entry per variable, either an element type
            alone, a 1-tuple ``(element_type,)`` or a tuple
            ``(element_type, shape)``; variables without an entry are declared
            as ``TensorProto.FLOAT`` with no shape
        :return: instance of :class:`onnx_array_api.light_api.Vars`

        If the checker fails, try `shape=[]`.
        """
        outputs = []
        for i, v in enumerate(self.vars_):
            elem_type, shape = TensorProto.FLOAT, None
            if i < len(elem_type_shape):
                spec = elem_type_shape[i]
                if not isinstance(spec, (tuple, list)):
                    # A bare element type without shape.
                    elem_type = spec
                elif len(spec) < 2:
                    elem_type = spec[0]
                else:
                    elem_type, shape = spec
            outputs.append(v.vout(elem_type=elem_type, shape=shape))
        return Vars(self.parent, *outputs)