from functools import partial
from typing import List, NoReturn, Optional, Union
class Var:
    """
    Traceable variable name.

    Wraps the name of a result created while tracing a function.
    The instance holds no concrete value: comparing it, indexing it,
    converting it with ``int()`` or calling ``len()`` on it raises a
    :class:`RuntimeError` explaining that the traced function needs
    a real tensor at that point.

    :param name: result name
    :param builder: optional builder, only used to append its debugging
        information (``get_debug_msg()``) to the error messages
    """

    def __init__(
        self,
        name: str,
        builder: Optional["GraphBuilder"] = None,  # noqa: F821
    ):
        assert isinstance(name, str), f"Unexpected type {type(name)} for name"
        self.name = name
        self.builder = builder

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name!r})"

    def _raise(self, name: str) -> NoReturn:
        """
        Raises a :class:`RuntimeError` stating that method ``name``
        cannot be evaluated on a traced variable.

        :param name: name of the special method being called
        :raises RuntimeError: always
        """
        msg = (
            f"The function being traced required to executed "
            f"with real Tensor. {self!r} cannot be evaluated with method {name!r}."
        )
        if self.builder is not None:
            msg += self.builder.get_debug_msg()
        raise RuntimeError(msg)

    # Comparisons need a concrete value, they all fail on purpose.
    # Defining __eq__ also sets __hash__ to None, so instances are unhashable.
    def __eq__(self, _) -> NoReturn:
        self._raise("__eq__")

    def __lt__(self, _) -> NoReturn:
        self._raise("__lt__")

    def __gt__(self, _) -> NoReturn:
        self._raise("__gt__")

    def __le__(self, _) -> NoReturn:
        self._raise("__le__")

    def __ge__(self, _) -> NoReturn:
        self._raise("__ge__")

    # int() and len() call these methods with ``self`` only: an extra
    # parameter would turn the intended RuntimeError into a TypeError.
    def __int__(self) -> NoReturn:
        self._raise("__int__")

    def __getitem__(self, _) -> NoReturn:
        self._raise("__getitem__")

    def __len__(self) -> NoReturn:
        self._raise("__len__")
class OxsOpset:
    """
    Bridge with :epkg:`onnxscript`.

    Every operator listed in ``_implemented`` can be called as a method,
    e.g. ``op.Add(x, y)``, which forwards to :meth:`make_node`.
    Any other operator is still reachable through :meth:`make_node`.

    :param builder: builder
    """

    # defined for opset >= 18
    # name: number of expected outputs
    _implemented = {
        "Abs": 1,
        "Add": 1,
        "And": 1,
        "ArgMax": 1,
        "ArgMin": 1,
        "Cast": 1,
        "CastLike": 1,
        "Celu": 1,
        "Concat": 1,
        "Constant": 1,
        "ConstantOfShape": 1,
        "Cos": 1,
        "Div": 1,
        "Dropout": 2,
        "Elu": 1,
        "Equal": 1,
        "Exp": 1,
        "Expand": 1,
        "Flatten": 1,
        "Gather": 1,
        "GatherElements": 1,
        "GatherND": 1,
        "Gemm": 1,
        "Greater": 1,
        "GreaterOrEqual": 1,
        "Identity": 1,
        "Less": 1,
        "LessOrEqual": 1,
        "Log": 1,
        "LogSoftmax": 1,
        "MatMul": 1,
        "MaxPool": 2,
        "Mul": 1,
        "Neg": 1,
        "Not": 1,
        "Or": 1,
        "Pow": 1,
        "Range": 1,
        "Reciprocal": 1,
        "ReduceMax": 1,
        "ReduceMean": 1,
        "ReduceMin": 1,
        "ReduceSum": 1,
        "Relu": 1,
        "Reshape": 1,
        "ScatterElements": 1,
        "ScatterND": 1,
        "Shape": 1,
        "Sigmoid": 1,
        "Sin": 1,
        "Size": 1,
        "Slice": 1,
        "Softmax": 1,
        "Sqrt": 1,
        "Squeeze": 1,
        "Sub": 1,
        "Tile": 1,
        "Transpose": 1,
        "Unsqueeze": 1,
        "Where": 1,
    }

    def __init__(self, builder: "GraphBuilder"):  # noqa: F821
        self.builder = builder

    def __getattr__(self, name):
        """
        Returns :meth:`make_node` bound to operator ``name`` when it is
        listed in ``_implemented``. Python only calls this method once
        regular attribute lookup has failed.

        :raises AttributeError: for any other name
        """
        if name in self._implemented:
            return partial(self.make_node, name)
        # ``object`` defines no __getattr__, there is nothing to delegate to.
        raise AttributeError(
            f"Unable to access attribute {name!r}, "
            f"you can still use this operator with method 'make_node'."
        )

    def IsScalar(self, name: str) -> bool:
        """
        Tells if a result is a scalar, meaning its shape is ``()`` or ``(1,)``.

        :param name: result name or an object with a ``name`` attribute
        :return: True if the result is a scalar
        :raises RuntimeError: if the builder knows neither the shape nor
            a rank sufficient to decide (rank 1 without a shape)
        """
        name = name if isinstance(name, str) else name.name
        if self.builder.has_shape(name):
            return self.builder.get_shape(name) in (tuple(), (1,))
        if self.builder.has_rank(name):
            rank = self.builder.get_rank(name)
            if rank == 0:
                return True
            if rank >= 2:
                # A tensor of rank >= 2 never has shape () or (1,).
                return False
            # rank 1: (1,) is a scalar, any other length is not,
            # the shape is required to decide.
        raise RuntimeError(
            f"Unable to tell if {name!r} is scalar{self.builder.get_debug_msg()}"
        )

    def Rank(self, name: str) -> int:
        """
        Returns the rank of a result known by the builder.

        :param name: result name or an object with a ``name`` attribute
        :return: rank
        """
        name = name if isinstance(name, str) else name.name
        assert self.builder.has_rank(
            name
        ), f"Rank is missing for name={name!r}{self.builder.get_debug_msg()}"
        return self.builder.get_rank(name)

    def make_node(
        self,
        op_type: str,
        *inputs: Optional[Union[str, List[str]]],
        outputs: Optional[Union[int, List[str], str]] = None,
        domain: str = "",
        name: Optional[str] = None,
        **kwargs,
    ):
        """
        Creates a node.

        :param op_type: type
        :param inputs: inputs; each one is either a result name, an object
            with a ``name`` attribute (:class:`Var`, ``torch.fx.Node``),
            None for an omitted optional input, or any other value which
            becomes an initializer through ``builder.make_initializer``
        :param outputs: outputs; defaults to the number of outputs registered
            in ``_implemented`` (mandatory for any operator not listed there)
        :param domain: domain
        :param name: name, defaults to the class name
        :param kwargs: additional arguments (node attributes)
        :return: output name wrapped into a :class:`Var`,
            or a tuple of them if the builder returns a tuple
        """
        assert not op_type.startswith("Reduce") or self.builder.main_opset >= 18, (
            f"Reduce operator {op_type!r} is not tested for opset < 18"
            f"{self.builder.get_debug_msg()}"
        )
        if name is None:
            # This node is created by an external converting library.
            name = type(self).__name__
        if outputs is None:
            # Raises KeyError for an operator missing from _implemented.
            outputs = self._implemented[op_type]

        new_inputs = []
        for i in inputs:
            assert not isinstance(
                i, (list, tuple)
            ), f"Wrong inputs for operator {op_type!r}: {inputs!r}"
            if i is None:
                # ONNX denotes an omitted optional input with an empty name.
                new_inputs.append("")
            elif isinstance(i, str):
                new_inputs.append(i)
            elif hasattr(i, "name"):
                # Var or torch.fx.Node
                new_inputs.append(i.name)
            else:
                cst_name = self.builder.make_initializer(
                    "", i, msg=f"input {i} of op_type={op_type!r}", source="OxsOpset.cst"
                )
                new_inputs.append(cst_name)

        outputs = self.builder.make_node(
            op_type, new_inputs, outputs=outputs, domain=domain, name=name, **kwargs
        )
        if isinstance(outputs, tuple):
            return tuple(Var(o, self.builder) for o in outputs)
        return Var(outputs, self.builder)