Source code for experimental_experiment.xbuilder.graph_builder_opset

from functools import partial
from typing import List, Optional, Union
import numpy as np


class Opset:
    """
    Makes it easier to write onnx graph.
    The method name is the node type.

    Any attribute which is not a regular member of the class is turned
    into a function creating a node of that type: ``opset.Add(x, y)``
    calls ``builder.make_node("Add", [x, y], ...)``.

    :param graph_builder: the builder
    :param allow_unknown: stored on the instance; it is not consulted
        by any method of this class. An operator missing from
        ``_implemented`` is assumed to produce one output
        (``Split`` reads ``num_outputs``).
    """

    # defined for opset >= 18
    # name: number of expected outputs
    _implemented = {
        "Abs": 1,
        "Add": 1,
        "And": 1,
        "ArgMax": 1,
        "ArgMin": 1,
        "Cast": 1,
        "CastLike": 1,
        "Celu": 1,
        "Concat": 1,
        "Constant": 1,
        "ConstantOfShape": 1,
        "Cos": 1,
        "Cosh": 1,
        "Div": 1,
        "Dropout": 2,
        "Elu": 1,
        "Equal": 1,
        "Exp": 1,
        "Expand": 1,
        "Flatten": 1,
        "Gather": 1,
        "GatherElements": 1,
        "GatherND": 1,
        "Gemm": 1,
        "Greater": 1,
        "GreaterOrEqual": 1,
        "Identity": 1,
        "MatMul": 1,
        "MaxPool": 2,
        "Mul": 1,
        "Less": 1,
        "LessOrEqual": 1,
        "Log": 1,
        "LogSoftmax": 1,
        "Neg": 1,
        "Not": 1,
        "Or": 1,
        "Pow": 1,
        "Range": 1,
        "Reciprocal": 1,
        "ReduceMax": 1,
        "ReduceMean": 1,
        "ReduceMin": 1,
        "ReduceSum": 1,
        "Relu": 1,
        "Reshape": 1,
        "ScatterElements": 1,
        "ScatterND": 1,
        "Shape": 1,
        "Sigmoid": 1,
        "Sin": 1,
        "Sinh": 1,
        "Slice": 1,
        "Softmax": 1,
        "Sqrt": 1,
        "Squeeze": 1,
        "Sub": 1,
        "Tile": 1,
        "Transpose": 1,
        "Trilu": 1,
        "Unsqueeze": 1,
        "Where": 1,
    }

    def __init__(
        self,
        builder: "GraphBuilder",  # noqa: F821
        allow_unknown: bool = False,
    ):
        self.builder = builder
        self.allow_unknown = allow_unknown

    def __getattr__(self, name):
        """
        Returns a function creating a node of type *name*.

        Only called when the regular attribute lookup failed.

        :param name: operator type
        :return: a callable taking the node inputs and attributes
        :raises AttributeError: for dunder names, so that protocols probing
            optional special methods (``copy``, ``pickle``, ``hasattr``)
            see them as missing instead of receiving a node factory
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{self.__class__.__name__!r} object has no attribute {name!r}"
            )
        if name in self._implemented:
            return partial(self.make_node, name)
        # Unknown operator: the number of outputs is guessed by _make_node.
        return partial(self._make_node, name)

    def _make_node(self, op_type, *args, outputs=None, **kwargs):
        """
        Creates a node for an operator which may be missing from
        ``_implemented`` and guesses the number of outputs if not given.

        :param op_type: operator type
        :param args: node inputs, see :meth:`make_node`
        :param outputs: number of outputs or output names,
            None to let the method decide
        :param kwargs: domain, name and node attributes
        :return: whatever ``builder.make_node`` returns
        """
        if outputs is None:
            if op_type in self._implemented:
                outputs = self._implemented[op_type]
            elif op_type == "Split" and kwargs.get("domain", "") == "":
                assert "num_outputs" in kwargs, (
                    "Number of outputs is not implemented yet for operator "
                    f"{op_type!r} and kwargs={kwargs}"
                )
                outputs = kwargs["num_outputs"]
            else:
                # We assume there is only one output.
                outputs = 1
        return self.make_node(op_type, *args, outputs=outputs, **kwargs)

    def make_node(
        self,
        op_type: str,
        *inputs: Optional[Union[str, List[str]]],
        outputs: Optional[Union[int, List[str], str]] = None,
        domain: str = "",
        name: Optional[str] = None,
        **kwargs,
    ):
        """
        Normalizes the inputs into names and adds a node to the builder.

        Every input is converted as follows:

        * a string is kept as it is,
        * an object with an attribute ``name`` and no attribute ``detach``
          (such as a ``torch.fx.Node``) is replaced by its name,
        * None stands for a missing optional input and becomes ``""``,
        * a numpy array is registered as an initializer and replaced
          by the name the builder returns.

        :param op_type: operator type
        :param inputs: node inputs
        :param outputs: number of outputs or output names, if None,
            the number is taken from ``_implemented``
            (``KeyError`` if the operator is not listed)
        :param domain: operator domain
        :param name: node name, the class name is used if empty
        :param kwargs: node attributes
        :return: whatever ``builder.make_node`` returns
        :raises AssertionError: if an input has an unexpected type
        """
        assert (
            op_type != "Split" or outputs != 1
        ), f"Operator Split is useless with one output, inputs={inputs}, outputs={outputs}"
        if outputs is None:
            outputs = self._implemented[op_type]
        # A constant shape given to Reshape must be stored as int64.
        assert (
            op_type != "Reshape"
            or len(inputs) != 2
            or not isinstance(inputs[1], np.ndarray)
            or inputs[1].dtype == np.int64
        ), f"Suspicious shape {inputs[1]!r} for a Reshape{self.builder.get_debug_msg()}"

        new_inputs = []
        for i in inputs:
            assert not isinstance(
                i, (list, tuple)
            ), f"Wrong inputs for operator {op_type!r}: {inputs!r}"
            if isinstance(i, str):
                new_inputs.append(i)
            elif hasattr(i, "name") and not hasattr(i, "detach"):
                # torch.fx.Node
                assert i.name is not None, f"Unexpected name for type {type(i)}"
                new_inputs.append(i.name)
            elif i is None:
                # Optional input
                new_inputs.append("")
            elif isinstance(i, np.ndarray):
                assert 0 not in i.shape, (
                    f"Not implemented for type(i)={type(i)}, i={i}, "
                    f"inputs={inputs!r}, op_type={op_type!r}, i.shape={i.shape}"
                    f"{self.builder.get_debug_msg()}"
                )
                # The source tag tells the builder where the constant comes from.
                if i.dtype == np.int64 and i.size < 16:
                    source = "Opset.make_node.1/Shape"
                elif i.size < 2:
                    source = "Opset.make_node.1/Small"
                else:
                    source = "Opset.make_node.0"
                cst_name = self.builder.make_initializer(
                    "", i, msg=f"input {i} of op_type={op_type!r}", source=source
                )
                new_inputs.append(cst_name)
            else:
                raise AssertionError(
                    f"Not implemented for type(i)={type(i)}, i={i}, "
                    f"inputs={inputs!r}, op_type={op_type!r}{self.builder.get_debug_msg()}"
                )

        assert None not in new_inputs
        return self.builder.make_node(
            op_type,
            new_inputs,
            outputs=outputs,
            domain=domain,
            name=name or self.__class__.__name__,
            **kwargs,
        )

    @staticmethod
    def _iaxes(op_type, axes) -> List[int]:
        """
        Converts constant axes into a list of integers usable as
        a node attribute.

        :param op_type: operator type, only used in the error message
        :param axes: a numpy array (0-d arrays included), an integer
            or a numpy integer
        :return: list of axes
        :raises RuntimeError: if *axes* is anything else, a result name
            for example, the axes are then only known at runtime
        """
        if isinstance(axes, np.ndarray):
            # atleast_1d: tolist() on a 0-d array returns a scalar, not a list.
            return np.atleast_1d(axes).tolist()
        if isinstance(axes, np.integer):
            return [int(axes)]
        if isinstance(axes, int):
            return [axes]
        raise RuntimeError(f"Unable to call {op_type} on a dynamic input axis={axes}")

    def ReduceMaxAnyOpset(self, *args, **kwargs):
        """
        Adds a node ReduceMax. The axes are given as a second input if
        the main opset is >= 18, and as an attribute otherwise.
        """
        if len(args) == 1:
            return self.ReduceMax(*args, **kwargs)
        assert len(args) == 2, f"ReduceMaxAnyOpset expects 2 arguments not {len(args)}"
        if self.builder.main_opset >= 18:
            return self.ReduceMax(*args, **kwargs)
        return self.ReduceMax(args[0], axes=self._iaxes("ReduceMax", args[1]), **kwargs)

    def ReduceMeanAnyOpset(self, *args, **kwargs):
        """
        Adds a node ReduceMean. The axes are given as a second input if
        the main opset is >= 18, and as an attribute otherwise.
        """
        if len(args) == 1:
            return self.ReduceMean(*args, **kwargs)
        assert len(args) == 2, f"ReduceMeanAnyOpset expects 2 arguments not {len(args)}"
        if self.builder.main_opset >= 18:
            return self.ReduceMean(*args, **kwargs)
        return self.ReduceMean(args[0], axes=self._iaxes("ReduceMean", args[1]), **kwargs)

    def ReduceSumAnyOpset(self, *args, **kwargs):
        """
        Adds a node ReduceSum. The axes are given as a second input if
        the main opset is >= 13, and as an attribute otherwise.
        """
        if len(args) == 1:
            return self.ReduceSum(*args, **kwargs)
        assert len(args) == 2, f"ReduceSumAnyOpset expects 2 arguments not {len(args)}"
        if self.builder.main_opset >= 13:
            return self.ReduceSum(*args, **kwargs)
        return self.ReduceSum(args[0], axes=self._iaxes("ReduceSum", args[1]), **kwargs)

    def SqueezeAnyOpset(self, *args, **kwargs):
        """
        Adds a node Squeeze. The axes are given as a second input if
        the main opset is >= 13, and as an attribute otherwise.
        """
        if len(args) == 1 and len(kwargs) == 0:
            return self.Squeeze(*args)
        assert len(args) == 2, f"SqueezeAnyOpset expects 2 arguments not {len(args)}"
        if self.builder.main_opset >= 13:
            return self.Squeeze(*args, **kwargs)
        return self.Squeeze(args[0], axes=self._iaxes("Squeeze", args[1]), **kwargs)

    def UnsqueezeAnyOpset(self, *args, **kwargs):
        """
        Adds a node Unsqueeze. The axes are given as a second input if
        the main opset is >= 13, and as an attribute otherwise.
        """
        if len(args) == 1 and len(kwargs) == 0:
            return self.Unsqueeze(*args)
        assert len(args) == 2, f"UnsqueezeAnyOpset expects 2 arguments not {len(args)}"
        if self.builder.main_opset >= 13:
            return self.Unsqueeze(*args, **kwargs)
        return self.Unsqueeze(args[0], axes=self._iaxes("Unsqueeze", args[1]), **kwargs)