from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy
from onnx import helper, numpy_helper, ModelProto, NodeProto, TensorProto
from .einsum_config import (
DEFAULT_IR_VERSION,
DEFAULT_OPSET,
guess_proto_dtype,
)
from .einsum_impl_ext import (
numpy_extended_dot,
numpy_diagonal,
_numpy_extended_dot_equation,
numpy_extended_dot_python,
numpy_extended_dot_matrix,
)
from .blas_lapack import gemm_dot
def single_axes(axes: Optional[Iterable[int]]) -> Optional[List[int]]:
    """
    Tells which axes of a row are single (inserted) dimensions.

    Every value of *axes* is either the position of the axis in the
    original matrix (a non-negative value) or -1, meaning this axis is
    a single dimension added to align all the dimensions based on the
    einsum equation.

    :param axes: axes described above (tuple, list or 1D numpy array),
        or None
    :return: None if *axes* is None, otherwise a list of integers in
        the set `{1, 2}`: 1 for a single axis, 2 otherwise
    """
    if axes is None:
        return None
    return [1 if a == -1 else 2 for a in axes]
[docs]
class EinsumSubOp:
"""
Defines a sub operation used in Einsum decomposition.
:param full_dim: dimension of the result
:param name: name (reshape, transpose, reduce_sum, matmul, id,
squeeze, diagonal, mul, batch_dot)
:param inputs: inputs
:param kwargs: arguments
Operator suffixed by `_mm` (*transpose_mm*, *reduce_sum_mm*)
are equivalent to the same operator without the suffix
but takes two inputs and only changes the first one.
Attributes `_info` summarizes the known information
about dimensions. Many of them are empty because inserted.
Value `1` means it was the case, `2` means it is a plain dimension.
"""
_allowed = {
"expand_dims",
"transpose",
"reduce_sum",
"matmul",
"id",
"squeeze",
"diagonal",
"mul",
"batch_dot",
"transpose_mm",
"reduce_sum_mm",
}
def __init__(
self,
full_dim: int,
name: str,
*inputs: List["EinsumSubOp"],
**kwargs: Dict[str, Any],
):
self.full_dim = full_dim
self.name = name
self.inputs = inputs
self.kwargs = kwargs
self._info: Dict[str, Any] = {}
assert (
name in EinsumSubOp._allowed
), f"Unexpected name {name!r}. It should be in {EinsumSubOp._allowed!r}."
assert len(inputs) in (
1,
2,
), f"Inputs must contains 1 or 2 inputs not {len(inputs)}."
if name == "matmul" and len(inputs) != 2:
raise RuntimeError(
"Inputs must contains 2 inputs not %d for operator 'matmul'."
"" % len(inputs)
)
for i, inp in enumerate(inputs):
assert isinstance(
inp, (int, EinsumSubOp)
), "Input %d has type %r, int or EinsumSubOp is expected." "" % (
i,
type(inp),
)
self._check_()
def _check_(self):
if self.name == "transpose":
self._check_arg_("perm", tuple)
perm = self.kwargs["perm"]
assert len(perm) == len(
set(perm)
), f"perm has duplicated values {perm!r} (name={self.name!r})."
assert list(perm) != list(
range(len(perm))
), f"Transpose = identity perm={perm}. It must be removed."
elif self.name == "matmul":
self._check_arg_("axes", tuple)
self._check_arg_("left", tuple)
self._check_arg_("right", tuple)
axes = self.kwargs["axes"]
left = self.kwargs["left"]
right = self.kwargs["right"]
for a in axes:
if a in left and a in right:
raise RuntimeError(
"One axis belongs to every set (axes, left, right). "
"axes=%r, left=%r, right=%r." % (axes, left, right)
)
def __repr__(self) -> str:
inps = ", ".join(map(str, self.inputs))
kw = ", ".join(f"{k}={w!r}" for k, w in self.kwargs.items())
m = f"{self.__class__.__name__}({self.name!r}, {inps}, {kw})"
return m
[docs]
def dot_label(self) -> Optional[str]:
"""
Displays some informations useful to understand the operator.
"""
if self.name == "matmul":
ndim = self.kwargs["ndim"]
axes = self.kwargs["axes"]
left = self.kwargs["left"]
right = self.kwargs["right"]
eq = _numpy_extended_dot_equation(ndim, ndim, axes, left, right)
eq = eq.replace(">", "\\\\=")
return "~" + eq
return None
def _check_arg_(self, name: str, typ: type, empty: bool = False):
assert (
name in self.kwargs
), f"Parameter {name!r} not found for operator {self.name!r}."
if empty and self.kwargs[name] is None:
return
assert isinstance(
self.kwargs[name], typ
), "Unexpected type %r for parameter %r and parameter %r." "" % (
type(self.kwargs[name]),
name,
self.name,
)
def _check_row_(
self,
row: numpy.ndarray,
inp: bool = False,
verbose: bool = False,
):
"""
Checks input or output is valid.
"""
if verbose:
if inp:
print("<<" if inp else ">>", self.name, row, self.kwargs)
else:
print("<<" if inp else ">>", self.name, row)
def _compute_output_row_id(
self,
row: numpy.ndarray,
row2: numpy.ndarray,
ab: bool = False,
verbose: bool = False,
):
assert not ab, "ab option not allowed."
self._check_row_(row, True, verbose=verbose)
row[:] = row2[:]
self._check_row_(row, verbose=verbose)
def _compute_output_row_transpose(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
if ab:
assert row2 is not None
self._compute_output_row_transpose(row2, verbose=verbose)
return
self._check_row_(row, True, verbose=verbose)
self._check_arg_("perm", tuple)
assert len(self.kwargs["perm"]) == len(
row
), f"Unexpected permutation {self.kwargs['perm']!r} (row={row!r})."
perm = self.kwargs["perm"]
cpy = row.copy()
for i, p in enumerate(perm):
row[i] = cpy[p]
self._check_row_(row, verbose=verbose)
def _compute_output_row_transpose_mm(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert ab, "ab must be True."
self._check_row_(row, True, verbose=verbose)
assert row2 is not None, "transpose_mm expects a second input."
self._compute_output_row_transpose(row, row2=None, verbose=verbose)
def _compute_output_row_expand_dims(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert not ab, "ab option not allowed."
self._check_row_(row, True, verbose=verbose)
self._check_arg_("axes", tuple)
axes = self.kwargs["axes"]
for axis in axes:
assert isinstance(axis, tuple), (
"Parameter axes of expand_dims should be a tuple of "
"tuple, axes=%r." % axes
)
assert row[axis[1]] == -1, "Dimension should be -1 in row %r axis=%r." % (
row,
self.kwargs["axis"],
)
self._check_row_(row, verbose=verbose)
def _compute_output_row_reduce_sum(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert not ab, "ab option not allowed."
self._check_row_(row, True, verbose=verbose)
self._check_arg_("axes", tuple)
for a in self.kwargs["axes"]:
row[a] = -1
self._check_row_(row, verbose=verbose)
def _compute_output_row_reduce_sum_mm(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert ab, "ab must be true."
self._check_row_(row2, True, verbose=verbose)
assert row2 is not None, "reduce_sum_mm expects a second input."
self._compute_output_row_reduce_sum(row, row2=None, verbose=verbose)
def _compute_output_row_squeeze(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert not ab, "ab option not allowed."
self._check_row_(row, True, verbose=verbose)
self._check_arg_("axes", tuple)
for a in self.kwargs["axes"]:
row[a] = -1
self._check_row_(row, verbose=verbose)
def _compute_output_row_diagonal(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert not ab, "ab option not allowed."
self._check_row_(row, True, verbose=verbose)
self._check_arg_("diag", list)
to_remove = []
for choice, choices in self.kwargs["diag"]:
for ch in choices:
if ch != choice:
to_remove.append(ch)
for i in range(len(row)):
if row[i] in choices:
if row[i] != choice:
row[i] = choice
to_remove.sort()
for r in to_remove:
for i in range(len(row)):
assert (
row[i] != r
), "Unexpected result r=%r row=%r to_remove=%r " "diag=%r." % (
r,
row,
to_remove,
self.kwargs["diag"],
)
if row[i] > r:
row[i] -= 1
self._check_row_(row, verbose=verbose)
def _compute_output_row_matmul(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert ab, "ab must be True."
assert row2 is not None, "row2 must be defined"
self._check_row_(row, True, verbose=verbose)
self._check_row_(row2, True, verbose=verbose)
self._check_arg_("axes", tuple)
self._check_arg_("left", tuple)
self._check_arg_("right", tuple)
self._check_arg_("ndim", int)
assert row2 is not None, "matmul expects two inputs."
if verbose:
ndim = self.kwargs["ndim"]
axes = self.kwargs["axes"]
left = self.kwargs["left"]
right = self.kwargs["right"]
print(
" MATMUL %r @ %r axes=%r left=%r right=%r - eq=%s"
% (
row,
row2,
axes,
left,
right,
_numpy_extended_dot_equation(ndim, ndim, axes, left, right),
)
)
row2[:] = numpy.maximum(row, row2)
for a in self.kwargs["axes"]:
if a not in self.kwargs["right"]:
row2[a] = -1
self._check_row_(row2, verbose=verbose)
def _compute_output_row_batch_dot(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert ab, "ab must be True."
self._check_row_(row, True, verbose=verbose)
self._check_row_(row2, True, verbose=verbose)
self._check_arg_("batch_axes", tuple)
self._check_arg_("keep_axes", tuple, empty=True)
self._check_arg_("sum_axes", tuple)
self._check_arg_("left", tuple)
self._check_arg_("right", tuple)
self._check_arg_("ndim", int)
assert row2 is not None, "batch_dot expects two inputs."
if verbose:
batch_axes = self.kwargs["batch_axes"]
keep_axes = self.kwargs["keep_axes"]
sum_axes = self.kwargs["sum_axes"]
left = self.kwargs["left"]
right = self.kwargs["right"]
ndim = self.kwargs["ndim"]
print(
" BATCH_DOT batch_axes=%r keep_axes=%r sum_axes=%r "
"left=%r right=%r eq=%r"
% (
batch_axes,
keep_axes,
sum_axes,
left,
right,
_numpy_extended_dot_equation(ndim, ndim, sum_axes, left, right),
)
)
row2[:] = numpy.maximum(row, row2)
for a in self.kwargs["sum_axes"]:
if a not in self.kwargs["right"]:
row2[a] = -1
self._check_row_(row2, verbose=verbose)
def _compute_output_row_mul(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
assert ab, "ab must be True."
self._check_row_(row, True, verbose=verbose)
self._check_row_(row2, True, verbose=verbose)
assert row2 is not None, "mul expects two inputs."
if verbose:
print(f" MUL {row!r} @ {row2!r}")
row2[:] = numpy.maximum(row, row2)
self._check_row_(row2, verbose=verbose)
[docs]
def compute_output_row(
self,
row: numpy.ndarray,
row2: Optional[numpy.ndarray] = None,
ab: bool = False,
verbose: bool = False,
):
"""
Updates *row* based on the operator.
"""
method_name = f"_compute_output_row_{self.name}"
meth = getattr(self, method_name, None)
if meth is None:
raise NotImplementedError(
f"compute_output_row not implemented for {self.name!r}."
)
if verbose and ab:
print(" -- called as a binary operator")
self.add_info(i_row=single_axes(row), i_row2=single_axes(row2))
meth(row, row2=row2, ab=ab, verbose=verbose)
self.add_info(o_row=single_axes(row), o_row2=single_axes(row2))
[docs]
def add_info(self, **kwargs: Dict[str, Any]):
"""
Adds information to the node.
:param kwargs: dictionary
"""
for k, v in kwargs.items():
assert (
k not in self._info
), f"Key {k!r} already added (operator {self.name!r})."
self._info[k] = v
def _check_inputs_(self, n_expected: int, check_dim: bool = False):
assert (
len(self.inputs) == n_expected
), "Number of inputs must be %d not %d for operator %r." "" % (
n_expected,
len(self.inputs),
self.name,
)
def _check_shape_(self, m: numpy.ndarray):
assert (
len(m.shape) == self.full_dim
), "Number of dimensions %r is different from expected value " "%d." % (
m.shape,
self.full_dim,
)
def _get_data(self, data: Dict[int, Any], key: Union[int, "EinsumSubOp"]) -> Any:
if isinstance(key, int):
assert key in data, "Unable to find key %d in %r." % (
key,
list(sorted(data)),
)
return data[key]
if isinstance(key, EinsumSubOp):
assert id(key) in data, "Unable to find key %d in %r." % (
id(key),
list(sorted(data)),
)
return data[id(key)]
raise TypeError(f"Unexpected input type {type(key)!r}.")
def _apply_id(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(1)
inp = self.inputs[0]
output = self._get_data(data, inp)
return output
def _apply_diagonal(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(1)
inp = self.inputs[0]
m = self._get_data(data, inp)
if verbose:
print(f"- {self.name}, shape={m.shape!r} diag={self.kwargs['diag']!r}")
diag = self.kwargs["diag"]
if len(diag) != 1:
raise NotImplementedError(
f"Not implemented with more than one duplicated indice {diag!r}."
)
diag0 = diag[0]
output = numpy_diagonal(m, axis=diag0[0], axes=diag0[1])
return output
def _apply_expand_dims(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(1)
inp = self.inputs[0]
m = self._get_data(data, inp)
if verbose:
print(f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
output = m
for axis in reversed(self.kwargs["axes"]):
output = numpy.expand_dims(output, axis[0])
return output
def _apply_transpose(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(1, True)
inp = self.inputs[0]
m = self._get_data(data, inp)
self._check_shape_(m)
if verbose:
print(f"- {self.name}, shape={m.shape!r} perm={self.kwargs['perm']!r}")
output = numpy.transpose(m, self.kwargs["perm"])
self._check_shape_(output)
return output
def _apply_transpose_mm(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(2, True)
inp = self.inputs[0]
m = self._get_data(data, inp)
self._check_shape_(m)
if verbose:
print(f"- {self.name}, shape={m.shape!r} perm={self.kwargs['perm']!r}")
output = numpy.transpose(m, self.kwargs["perm"])
self._check_shape_(output)
return output
def _apply_matmul(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(2)
inp1 = self.inputs[0]
inp2 = self.inputs[1]
m1 = self._get_data(data, inp1)
m2 = self._get_data(data, inp2)
self._check_shape_(m1)
self._check_shape_(m2)
axes = self.kwargs["axes"]
left = self.kwargs["left"]
right = self.kwargs["right"]
if verbose:
print(
"- %s, shapes=%r @ %r axes=%r left=%r right=%r"
% (self.name, m1.shape, m2.shape, axes, left, right)
)
impl = kwargs.get("matmul_impl", None)
if impl == "pyf":
output = numpy_extended_dot_matrix(
m1, m2, axes, left, right, verbose=verbose
)
elif impl == "py":
output = numpy_extended_dot_python(
m1, m2, axes, left, right, verbose=verbose
)
elif impl is None:
output = numpy_extended_dot(m1, m2, axes, left, right, verbose=verbose)
else:
raise ValueError(f"Unknown implementation of numpy_extended_dot ({impl}).")
self._check_shape_(output)
return output
def _apply_mul(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(2)
inp1 = self.inputs[0]
inp2 = self.inputs[1]
m1 = self._get_data(data, inp1)
m2 = self._get_data(data, inp2)
self._check_shape_(m1)
self._check_shape_(m2)
if verbose:
print(f"- {self.name}, shapes={m1.shape!r} @ {m2.shape!r}")
output = m1 * m2
self._check_shape_(output)
return output
def _apply_batch_dot(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(2)
inp1 = self.inputs[0]
inp2 = self.inputs[1]
m1 = self._get_data(data, inp1)
m2 = self._get_data(data, inp2)
self._check_shape_(m1)
self._check_shape_(m2)
batch_axes = self.kwargs["batch_axes"]
keep_axes = self.kwargs["keep_axes"]
sum_axes = self.kwargs["sum_axes"]
left = self.kwargs["left"]
right = self.kwargs["right"]
if verbose:
print(
"- %s, shapes=%r @ %r batch_axes=%r keep_axes=%r "
"sum_axes=%r"
% (self.name, m1.shape, m2.shape, batch_axes, keep_axes, sum_axes)
)
assert len(m1.shape) == len(m2.shape), (
"batch_dot only work with two tensors with the same number "
"of dimensions not %r @ %r." % (m1.shape, m2.shape)
)
dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
dimb = int(
-1 if keep_axes is None else numpy.prod([m1.shape[i] for i in keep_axes])
)
dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))
if verbose:
print(f"- {self.name}, reshape={m1.shape!r} into {dim0, dimb, dim1!r}")
print(f"- {self.name}, reshape={m2.shape!r} into {dim0b, dimb, dim2!r}")
m1sh = m1.reshape((dim0, dimb, dim1))
m2sh = m2.reshape((dim0b, dimb, dim2))
batch_kind = self.get_dot_kind()
if batch_kind in ("11", "N1", "N1"):
m1sh = m1sh.reshape((-1, m1sh.shape[-1]))
m2sh = m2sh.reshape((-1, m2sh.shape[-1]))
if verbose:
print(
"- %s, use gemm with shape %r, %r"
% (self.name, m1sh.shape, m2sh.shape)
)
dot = gemm_dot(m1sh, m2sh, False, True)
else:
dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))
# new shape
new_shape = (
[max(m1.shape[i], m2.shape[i]) for i in batch_axes]
+ [m1.shape[i] for i in left if i not in batch_axes]
+ [m2.shape[i] for i in right if i not in batch_axes]
)
while len(new_shape) < len(m1.shape):
new_shape.append(1)
if verbose:
taken = set(batch_axes) | set(sum_axes)
ax = [i for i in range(len(m1.shape)) if i not in taken]
print(
"- %s, shapes=%r @ %r -> %r"
% (self.name, m1sh.shape, m2sh.shape, dot.shape)
)
print(
"- %s, batch_axes=%r ax=%r new_shape=%r left=%r right=%r"
% (self.name, batch_axes, ax, new_shape, left, right)
)
output = dot.reshape(tuple(new_shape))
self._check_shape_(output)
return output
def _apply_reduce_sum(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(1)
inp = self.inputs[0]
m = self._get_data(data, inp)
self._check_shape_(m)
axes = self.kwargs["axes"]
if verbose:
print(f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
output = numpy.sum(m, axis=axes, keepdims=True)
self._check_shape_(output)
return output
def _apply_reduce_sum_mm(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(2, True)
inp = self.inputs[0]
m = self._get_data(data, inp)
self._check_shape_(m)
if verbose:
print(f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
output = numpy.sum(m, self.kwargs["axes"])
self._check_shape_(output)
return output
def _apply_squeeze(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
self._check_inputs_(1)
inp = self.inputs[0]
m = self._get_data(data, inp)
axes = self.kwargs["axes"]
if verbose:
print(f"- {self.name}, shape={m.shape!r} axes={self.kwargs['axes']!r}")
output = m
for a in axes[::-1]:
output = numpy.squeeze(output, axis=a)
return output
[docs]
def apply(
self, data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Any:
"""
Applies one operator on the data.
:param data: dictionary storing the results
:param verbose: prints out intermediate results
:param kwargs: additional parameters, see
methods `_apply*`
:return: output
Known additional paramaters:
* 'matmul_impl': if None calls :func:`numpy.einsum` through
:func:`numpy_extended_dot
<onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot>`
(default) or 'py' to call
:func:`numpy_extended_dot_python
<onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_python>`
instead.
"""
if verbose:
print()
print(
"apply %r (%s)."
% (self.name, ", ".join(map(lambda s: str(id(s)), self.inputs)))
)
method_name = f"_apply_{self.name}"
meth = getattr(self, method_name, None)
if meth is None:
raise NotImplementedError(f"apply not implemented for {self.name!r}.")
output = meth(data, verbose, **kwargs)
data[id(self)] = output
if verbose:
print("+ %s, shape=%r -- %d" % (self.name, output.shape, id(self)))
return output
def _onnx_name(self) -> str:
return "einsum%d_%s" % (id(self), self.name[:2])
def _check_onnx_opset_(self, opset: Optional[int], limit: int):
if opset is not None and opset < limit:
raise RuntimeError(
f"Opset ({opset!r}) must be >= {limit!r} for operator {self.name!r}."
)
def _to_onnx_id(
self,
names: List[str],
opset: Optional[int],
verbose: bool = False,
**kwargs: Dict[str, Any],
) -> Iterable[NodeProto]:
self._check_inputs_(1)
inp = self.inputs[0]
name = self._get_data(names, inp)
yield helper.make_node("Identity", [name], [self._onnx_name()])
def _to_onnx_expand_dims(
self,
names: List[str],
opset: Optional[int],
verbose: bool = False,
**kwargs: Dict[str, Any],
) -> Iterable[NodeProto]:
self._check_inputs_(1)
self._check_onnx_opset_(opset, 11)
inp = self.inputs[0]
name = self._get_data(names, inp)
axes = self.kwargs["axes"]
name_axes = name + "_axes"
yield numpy_helper.from_array(
numpy.array([a[1] for a in axes], dtype=numpy.int64), name=name_axes
)
s_axes = "".join(map(str, [a[1] for a in axes]))
yield helper.make_node(
"Unsqueeze",
[name, name_axes],
[self._onnx_name()],
name="Unsqueeze%s_%d" % (s_axes, id(self)),
)
def _to_onnx_squeeze(
self,
names: List[str],
opset: Optional[int],
verbose: bool = False,
**kwargs: Dict[str, Any],
) -> Iterable[NodeProto]:
self._check_inputs_(1)
self._check_onnx_opset_(opset, 11)
inp = self.inputs[0]
name = self._get_data(names, inp)
axes = self.kwargs["axes"]
name_axes = name + "_axes"
yield numpy_helper.from_array(
numpy.array(axes, dtype=numpy.int64), name=name_axes
)
s_axes = "".join(map(str, axes))
yield helper.make_node(
"Squeeze",
[name, name_axes],
[self._onnx_name()],
name="Squeeze%s_%d" % (s_axes, id(self)),
)
def _to_onnx_transpose(
self,
names: List[str],
opset: Optional[int],
verbose: bool = False,
**kwargs: Dict[str, Any],
) -> Iterable[NodeProto]:
self._check_inputs_(1)
inp = self.inputs[0]
name = self._get_data(names, inp)
perm = self.kwargs["perm"]
s_perm = "".join(map(str, perm))
yield helper.make_node(
"Transpose",
[name],
[self._onnx_name()],
perm=perm,
name="Transpose%s_%d" % (s_perm, id(self)),
)
def _to_onnx_reduce_sum(
self,
names: List[str],
opset: Optional[int],
verbose: bool = False,
**kwargs: Dict[str, Any],
) -> Iterable[NodeProto]:
self._check_inputs_(1)
self._check_onnx_opset_(opset, 11)
inp = self.inputs[0]
name = self._get_data(names, inp)
axes = self.kwargs["axes"]
name_axes = self._onnx_name() + "_axes"
yield numpy_helper.from_array(
numpy.array(axes, dtype=numpy.int64), name=name_axes
)
s_axes = "".join(map(str, axes))
yield helper.make_node(
"ReduceSum",
[name, name_axes],
[self._onnx_name()],
keepdims=1,
name="ReduceSum%s_%d" % (s_axes, id(self)),
)
def _to_onnx_mul(
self, data: List[Any], verbose: bool = False, **kwargs: Dict[str, Any]
) -> Iterable[NodeProto]:
self._check_inputs_(2)
inp1 = self.inputs[0]
inp2 = self.inputs[1]
m1 = self._get_data(data, inp1)
m2 = self._get_data(data, inp2)
yield helper.make_node("Mul", [m1, m2], [self._onnx_name()])
def _to_onnx_batch_dot(
self,
names: List[str],
opset: Optional[int],
verbose: bool = False,
**kwargs: Dict[str, Any],
) -> Iterable[NodeProto]:
self._check_inputs_(2)
self._check_onnx_opset_(opset, 13)
inp1, inp2 = self.inputs[:2]
name1 = self._get_data(names, inp1)
name2 = self._get_data(names, inp2)
batch_axes = self.kwargs["batch_axes"]
keep_axes = self.kwargs["keep_axes"]
sum_axes = self.kwargs["sum_axes"]
left = self.kwargs["left"]
right = self.kwargs["right"]
root = self._onnx_name()
def return_name_one():
name_one = root + "_1"
return name_one, numpy_helper.from_array(
numpy.array([1], dtype=numpy.int64), name=name_one
)
name_one = None
name_shape1 = root + "_shape1"
name_shape2 = root + "_shape2"
concat_left = []
concat_right = []
yield helper.make_node("Shape", [name1], [name_shape1])
yield helper.make_node("Shape", [name2], [name_shape2])
if len(batch_axes) > 0:
name_batch_axes = root + "_batch_axes"
yield numpy_helper.from_array(
numpy.array(batch_axes, dtype=numpy.int64), name=name_batch_axes
)
if len(sum_axes) > 0:
name_sum_axes = root + "_sum_axes"
yield numpy_helper.from_array(
numpy.array(sum_axes, dtype=numpy.int64), name=name_sum_axes
)
# dim0 = int(numpy.prod([m1.shape[i] for i in batch_axes]))
# dim0b = int(numpy.prod([m2.shape[i] for i in batch_axes]))
if len(batch_axes) > 1:
name_dim0 = root + "_dim0"
name_dim0b = root + "_dim0b"
name_dim0g = name_dim0 + "g"
name_dim0bg = name_dim0b + "g"
concat_left.append(name_dim0)
concat_right.append(name_dim0b)
yield helper.make_node(
"Gather", [name_shape1, name_batch_axes], [name_dim0g]
)
yield helper.make_node(
"Gather", [name_shape2, name_batch_axes], [name_dim0bg]
)
yield helper.make_node("ReduceProd", [name_dim0g], [name_dim0], keepdims=1)
yield helper.make_node(
"ReduceProd", [name_dim0bg], [name_dim0b], keepdims=1
)
elif len(batch_axes) == 1:
name_dim0g = root + "_dim0g"
name_dim0bg = root + "_dim0bg"
name_dim0 = name_dim0g
name_dim0b = name_dim0bg
concat_left.append(name_dim0)
concat_right.append(name_dim0b)
yield helper.make_node(
"Gather", [name_shape1, name_batch_axes], [name_dim0g]
)
yield helper.make_node(
"Gather", [name_shape2, name_batch_axes], [name_dim0bg]
)
else:
if name_one is None:
name_one, cst_init = return_name_one()
yield cst_init
name_dim0 = name_one
name_dim0b = name_one
concat_left.append(name_dim0)
concat_right.append(name_dim0b)
# dimb = int(-1 if keep_axes is None else numpy.prod(
# [m1.shape[i] for i in keep_axes]))
if keep_axes in (-1, None) or len(keep_axes) == 0:
name_dimb = root + "__1"
concat_left.append(name_dimb)
concat_right.append(name_dimb)
yield numpy_helper.from_array(
numpy.array([-1], dtype=numpy.int64), name=name_dimb
)
elif len(keep_axes) == 1:
name_keep_axes = root + "_keep_axes"
name_dimb = root + "_dimb"
name_dimbg = name_dimb
concat_left.append(name_dimb)
concat_right.append(name_dimb)
yield numpy_helper.from_array(
numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes
)
yield helper.make_node(
"Gather", [name_shape1, name_keep_axes], [name_dimbg]
)
else:
name_keep_axes = root + "_keep_axes"
name_dimb = root + "_dimb"
name_dimbg = name_dimb + "g"
concat_left.append(name_dimb)
concat_right.append(name_dimb)
yield numpy_helper.from_array(
numpy.array(keep_axes, dtype=numpy.int64), name=name_keep_axes
)
yield helper.make_node(
"Gather", [name_shape1, name_keep_axes], [name_dimbg]
)
yield helper.make_node("ReduceProd", [name_dimbg], [name_dimb], keepdims=1)
# dim1 = int(numpy.prod([m1.shape[i] for i in sum_axes]))
# dim2 = int(numpy.prod([m2.shape[i] for i in sum_axes]))
if len(sum_axes) == 0:
if name_one is None:
name_one, cst_init = return_name_one()
yield cst_init
name_dim1 = name_one
name_dim2 = name_one
concat_left.append(name_dim1)
concat_right.append(name_dim2)
elif len(sum_axes) == 1:
name_dim1 = root + "_dim1"
name_dim2 = root + "_dim2"
name_dim1g = name_dim1
name_dim2g = name_dim2
concat_left.append(name_dim1)
concat_right.append(name_dim2)
yield helper.make_node("Gather", [name_shape1, name_sum_axes], [name_dim1g])
yield helper.make_node("Gather", [name_shape2, name_sum_axes], [name_dim2g])
else:
name_dim1 = root + "_dim1"
name_dim2 = root + "_dim2"
name_dim1g = name_dim1 + "g"
name_dim2g = name_dim2 + "g"
concat_left.append(name_dim1)
concat_right.append(name_dim2)
yield helper.make_node("Gather", [name_shape1, name_sum_axes], [name_dim1g])
yield helper.make_node("Gather", [name_shape2, name_sum_axes], [name_dim2g])
yield helper.make_node("ReduceProd", [name_dim1g], [name_dim1], keepdims=1)
yield helper.make_node("ReduceProd", [name_dim2g], [name_dim2], keepdims=1)
batch_kind = self.get_dot_kind()
if batch_kind in ("11", "N1", "N1"):
# *shape1, *shape2
name_minus_one = root + "__01"
yield numpy_helper.from_array(
numpy.array([-1], dtype=numpy.int64), name=name_minus_one
)
name_agg_shape1_2 = root + f"_resh1_{batch_kind}"
name_agg_shape2_2 = root + f"_resh2_{batch_kind}"
yield helper.make_node(
"Concat", [name_minus_one, name_dim1], [name_agg_shape1_2], axis=0
)
yield helper.make_node(
"Concat", [name_minus_one, name_dim2], [name_agg_shape2_2], axis=0
)
# m1sh = m1.reshape((-1, dim1))
# m2sh = m2.reshape((-1, dim2))
name_agg1_2 = root + "_aresh1"
name_agg2_2 = root + "_aresh2"
yield helper.make_node("Reshape", [name1, name_agg_shape1_2], [name_agg1_2])
yield helper.make_node("Reshape", [name2, name_agg_shape2_2], [name_agg2_2])
# dot = gemm(m1sh, m2sh, False, True)
name_dot = root + "_gemm"
yield helper.make_node(
"Gemm",
[name_agg1_2, name_agg2_2],
[name_dot],
alpha=1.0,
beta=0.0,
transA=0,
transB=1,
)
else:
# *shape1, *shape2
name_agg_shape1 = root + "_resh1"
name_agg_shape2 = root + "_resh2"
yield helper.make_node("Concat", concat_left, [name_agg_shape1], axis=0)
yield helper.make_node("Concat", concat_right, [name_agg_shape2], axis=0)
# m1sh = m1.reshape((dim0, dimb, dim1))
# m2sh = m2.reshape((dim0b, dimb, dim2))
name_agg1 = root + "_aresh1"
name_agg2 = root + "_aresh2"
yield helper.make_node("Reshape", [name1, name_agg_shape1], [name_agg1])
yield helper.make_node("Reshape", [name2, name_agg_shape2], [name_agg2])
# dot = m1sh @ numpy.transpose(m2sh, (0, 2, 1))
name_agg2_tr = root + "_aresh2_tr"
yield helper.make_node(
"Transpose",
[name_agg2],
[name_agg2_tr],
perm=[0, 2, 1],
name=f"Transpose021_{id(self)}",
)
name_dot = root + "_dot"
yield helper.make_node("MatMul", [name_agg1, name_agg2_tr], [name_dot])
# new_shape = ([max(m1.shape[i], m2.shape[i]) for i in batch_axes] +
# [m1.shape[i] for i in left if i not in batch_axes] +
# [m2.shape[i] for i in right if i not in batch_axes])
concat_final = []
if len(batch_axes) > 0:
name_max_dim = root + "_max_dim"
concat_final.append(name_max_dim)
yield helper.make_node("Max", [name_dim0g, name_dim0bg], [name_max_dim])
left_set = list(sorted(set(left) - (set(batch_axes) & set(left))))
if len(left_set) > 0:
name_left_dim = root + "_left_dim"
name_left_set = root + "_left_set"
yield numpy_helper.from_array(
numpy.array(left_set, dtype=numpy.int64), name=name_left_set
)
yield helper.make_node(
"Gather", [name_shape1, name_left_set], [name_left_dim]
)
concat_final.append(name_left_dim)
right_set = list(sorted(set(right) - (set(batch_axes) & set(right))))
if len(right_set) > 0:
name_right_dim = root + "_right_dim"
name_right_set = root + "_right_set"
yield numpy_helper.from_array(
numpy.array(right_set, dtype=numpy.int64), name=name_right_set
)
yield helper.make_node(
"Gather", [name_shape2, name_right_set], [name_right_dim]
)
concat_final.append(name_right_dim)
name_new_shape = root + "_new_shape"
diff = self.full_dim - (len(batch_axes) + len(left_set) + len(right_set))
if diff > 0:
names_ones = root + "_ones"
yield numpy_helper.from_array(
numpy.array([1 for i in range(diff)], dtype=numpy.int64),
name=names_ones,
)
concat_final.append(names_ones)
yield helper.make_node("Concat", concat_final, [name_new_shape], axis=0)
name_final = root + "_final"
yield helper.make_node("Reshape", [name_dot, name_new_shape], [name_final])
[docs]
def to_onnx(
self,
names: List[str],
opset: Optional[int],
verbose: bool = False,
**kwargs: Dict[str, Any],
) -> Iterable[NodeProto]:
"""
Converts this node into ONNX. Enumerates all ONNX node
which participate to the conversion. The last one
is the final output.
:param names: dictionary where to find already converted name
:param opset: opset
:param verbose: prints out intermediate results
:param kwargs: additional parameter for the conversion
:return: output
"""
if opset is None:
opset = DEFAULT_OPSET
if verbose:
print()
print(
"to_onnx %r (%s) opset=%r."
% (self.name, ", ".join(map(lambda s: str(id(s)), self.inputs)), opset)
)
method_name = f"_to_onnx_{self.name}"
meth = getattr(self, method_name, None)
if meth is None:
if self.name.endswith("_mm"):
raise NotImplementedError(
"to_onnx not implemented for %r."
"You should call method simplify_mm_nodes "
"to remove it." % self.name
)
raise NotImplementedError(f"to_onnx not implemented for {self.name!r}.")
for node in meth(names, verbose=verbose, opset=opset, **kwargs):
if hasattr(node, "output"):
names[id(self)] = node.output[0]
if verbose:
print(
"+ OP %r -- (%s - %d)" % (node.output[0], self.name, id(self))
)
elif verbose:
# Initializer
print("+ CT %r -- (%s - %d)" % (node.name, self.name, id(self)))
yield node
[docs]
def get_dot_kind(self) -> str:
    """
    Tells which kind of matrix multiplication this node is.

    Every matrix multiplication can be either:

    * a simple multiplication (`M`) (undetected)
    * a 2D matrix multiplication (`11`)
    * a broadcasted matrix multiplication (`N1` or `1N`)
    * a batch matrix multiplication (`NN`)

    The first letter describes the left operand (``_info['i_row']``),
    the second one the right operand (``_info['i_row2']``): `N` when
    the largest value found on the batch axes is 2 (a plain dimension,
    see :func:`single_axes`), `1` otherwise or when there is no batch axis.

    :return: two-letter string among `11`, `N1`, `1N`, `NN`
    """
    batch_axes = self.kwargs["batch_axes"]
    left_row, right_row = self._info["i_row"], self._info["i_row2"]
    picked = [[row[axis] for axis in batch_axes] for row in (left_row, right_row)]
    return "".join("N" if vals and max(vals) == 2 else "1" for vals in picked)
[docs]
class GraphEinsumSubOp:
    """
    Class gathering all nodes produced to explicit einsum
    operators.

    :param letters: list of distinct letters
    :param mat: matrix, see :func:`analyse_einsum_equation
        <onnx_extended.tools.einsum.einsum_impl.analyse_einsum_equation>`
    :param lengths: lengths of every input
    :param duplicates: see :func:`analyse_einsum_equation
        <onnx_extended.tools.einsum.einsum_impl.analyse_einsum_equation>`

    Inputs are registered as integers (their position),
    operators as instances of :class:`EinsumSubOp` keyed by ``id(op)``.
    """

    def __init__(
        self,
        letters: str,
        mat: numpy.ndarray,
        lengths: List[int],
        duplicates: List[Dict[str, int]],
    ):
        # registered inputs (key=position) and operators (key=id(op))
        self._nodes: Dict[int, Union[int, EinsumSubOp]] = {}
        # marks in both directions: position -> op and id(op) -> position
        self._mark: Dict[int, Union[int, EinsumSubOp]] = {}
        # operators in insertion order, walked by __iter__
        self._ops: List[EinsumSubOp] = []
        self._inputs: Dict[int, Union[int, EinsumSubOp]] = {}
        self.last_op: Optional[Union[int, EinsumSubOp]] = None
        self.last_added_op: Optional[Union[int, EinsumSubOp]] = None
        self.metadata = dict(
            letters=letters,
            mat=mat,
            lengths=lengths,
            mat0=mat.copy(),
            duplicates=duplicates,
        )

    def append(self, op: Union[int, "EinsumSubOp"]) -> Optional["EinsumSubOp"]:
        """
        Adds one input or result.

        :param op: integer (an input) or an instance of :class:`EinsumSubOp
            <onnx_extended.tools.einsum.einsum_impl_classes.EinsumSubOp>`.
        :return: op or None if op is an integer
        :raises AssertionError: if *op* was already added
        :raises TypeError: if *op* is neither an integer nor an EinsumSubOp
        """
        if isinstance(op, int):
            assert op not in self._nodes, "Key %d already added." % op
            self._nodes[op] = op
            self.last_added_op = op
            self._inputs[op] = op
            return None
        if isinstance(op, EinsumSubOp):
            # operators are stored under id(op), the check must use the same key
            assert id(op) not in self._nodes, "Key %d already added, op=%r." % (
                id(op),
                op,
            )
            self._nodes[id(op)] = op
            self._ops.append(op)
            self.last_added_op = op
            return op
        raise TypeError(f"Unexpected type {type(op)!r}.")

    def mark_last_node(self):
        """
        Marks the last node as the final output.
        """
        assert self.last_added_op is not None, "last_added_op is None."
        self.mark(-1, self.last_added_op)

    def mark(self, i: int, op: "EinsumSubOp"):
        """
        Marks one input or result as an intermediate result
        after a full einsum step.

        :param i: a position, -1 for the final output
        :param op: an instance of :class:`EinsumSubOp
            <onnx_extended.tools.einsum.einsum_impl_classes.EinsumSubOp>`.
        :raises RuntimeError: if *i* is neither -1 nor a registered input
        :raises TypeError: if *op* is not an EinsumSubOp
        """
        assert isinstance(i, int), f"i must be an integer not {type(i)!r}."
        if i != -1 and i not in self._inputs:
            raise RuntimeError("Input %d was not registered in %r." % (i, self._inputs))
        if not isinstance(op, EinsumSubOp):
            raise TypeError(f"Unexpected type {type(op)!r}.")
        assert id(op) in self._nodes, "Key %d not found, op=%r." % (id(op), op)
        # both directions are stored, _replace_node_sequence relies on it
        self._mark[i] = op
        self._mark[id(op)] = i
        self.last_op = op

    def __iter__(self) -> Iterable["EinsumSubOp"]:
        "Iterates on nodes."
        yield from self._ops

    def to_dot(self, **kwargs: Dict[str, Any]) -> str:
        """
        Produces a graph in :epkg:`dot`.

        :param kwargs: additional graph options, they override
            the default values
        :return: string

        Inputs are labelled with their position and letters,
        operators with their name and arguments. Operators marked
        with :meth:`mark` are filled in red.
        """
        options = {
            "orientation": "portrait",
            "ranksep": "0.25",
            "nodesep": "0.05",
            "width": "0.5",
            "height": "0.1",
            "size": "5",
            "node": "[shape=record]",
        }
        options.update(kwargs)

        def d2s(d):
            # "k1=v1 k2=v2 ..." sorted by key
            return " ".join(f"{k}={v}" for k, v in sorted(d.items()))

        def d2sd(d):
            # same as d2s but only keeps values with more than one element
            return " ".join(
                f"{k}={','.join(map(str, v))}"
                for k, v in sorted(d.items())
                if len(v) > 1
            )

        rows = ["digraph{"]
        for k, v in options.items():
            if isinstance(v, str) and "[" in v:
                # attribute statement such as: node [shape=record];
                rows.append(f"{k} {v};")
            else:
                rows.append(f"{k}={v};")

        for k, v in self._nodes.items():
            if isinstance(v, int):
                # an input: letters sorted by their position in mat0
                let = [
                    (r, self.metadata["letters"][i])
                    for i, r in enumerate(self.metadata["mat0"][v])
                    if r != -1
                ]
                dup = self.metadata["duplicates"][v]
                dup = "" if dup is None else f" - {d2sd(dup)}"
                let.sort()
                letters = "".join(_[1] for _ in let)
                lab = "input %d\\\\n%s\\\\n%s%s" % (
                    v,
                    letters,
                    str(self.metadata["mat0"][v]),
                    dup,
                )
                sk = v
                extended_lab = ""
            else:
                lab = f"{v.name}\\\\n{d2s(v.kwargs)}"
                sk = id(v)
                extended_lab = v.dot_label()
                if extended_lab:
                    extended_lab = "\\\\n" + extended_lab

            if sk in self._mark and isinstance(self._mark[sk], int):
                la = self._mark[sk]
                lab = lab.replace("\\\\n", " - I%d\\\\n" % la)
                s = f'{k} [label="{lab}{extended_lab}" style=filled fillcolor=red];'
            else:
                s = f'{k} [label="{lab}{extended_lab}"];'
            rows.append(s)
            if not hasattr(v, "inputs"):
                continue
            for i in v.inputs:
                vid = i if isinstance(i, int) else id(i)
                s = "%d -> %d;" % (vid, k)
                rows.append(s)
        rows.append("}")
        return "\n".join(rows)

    def apply_sequence(
        self,
        *inputs: Any,
        verbose: bool = False,
        **kwargs: Dict[str, Any],
    ) -> Any:
        """
        Applies a sequence of operations on a list of inputs.

        :param inputs: inputs, indexed by their position
        :param verbose: prints out intermediate results
        :param kwargs: additional parameters,
            see :meth:`apply
            <onnx_extended.tools.einsum.einsum_impl_classes.EinsumSubOp.apply>`.
        :return: output of the last operator
        :raises AssertionError: if the graph has no operator
        """
        if verbose:
            print("######### apply_sequence")
        data = dict(enumerate(inputs))
        last = None
        for op in self:
            last = op.apply(data, verbose=verbose, **kwargs)
        assert last is not None, "Sequence of operations is empty."
        return last

    def clean_unused_nodes(self, verbose: bool = False):
        """
        Cleans nodes with unused outputs. Marked nodes are kept.
        The cleaning is repeated until nothing is removed because
        removing a node may leave its inputs unused.

        :param verbose: display intermediate information
        """

        def iteration(it):
            # Maps every operator to the ids of the operators consuming it.
            is_used: Dict[int, List[int]] = {}
            for node in self._ops:
                if not isinstance(node, EinsumSubOp):
                    continue
                is_used.setdefault(id(node), [])
                for inp in node.inputs:
                    if not isinstance(inp, EinsumSubOp):
                        continue
                    is_used.setdefault(id(inp), []).append(id(node))

            # Removes unused nodes unless they are marked.
            removed = {k for k, v in is_used.items() if not v}
            i_rem = [
                (i, id(op))
                for i, op in enumerate(self._ops)
                if isinstance(op, EinsumSubOp)
                and id(op) in removed
                and id(op) not in self._mark
            ]
            for i, idn in reversed(i_rem):
                if verbose:
                    print(
                        "[GraphEinsumSubOp.clean_nodes] remove node "
                        "i=%d: %d - id=%d" % (it, i, idn)
                    )
                del self._ops[i]
                del self._nodes[idn]
            return len(i_rem) > 0

        it = 1
        while iteration(it):
            it += 1

        self.last_op = None
        self.last_added_op = None

    def simplify_mm_nodes(self, verbose: bool = False):
        """
        Node name suffixed by `mm` are an artifact to keep
        the graph consistent while building it. They can
        now be replaced by the equivalent node without suffix `mm`:
        the suffix is removed and only the first input is kept.

        :param verbose: display intermediate information
        """
        for op in self:
            if not isinstance(op, EinsumSubOp) or not op.name.endswith("_mm"):
                continue
            if verbose:
                print(
                    "[GraphEinsumSubOp.simplify_mm_nodes] node %r"
                    " - id=%d" % (op.name, id(op))
                )
            assert (
                len(op.inputs) == 2
            ), "Expecting 2 inputs for node %r not %r id=%r." % (
                op.name,
                len(op.inputs),
                id(op),
            )
            op.name = op.name[:-3]
            op.inputs = op.inputs[:1]

    def _get_forward_nodes(self) -> Dict[int, List["EinsumSubOp"]]:
        """
        Returns the forward nodes: for every input position or operator id,
        the list of operators consuming it.
        """
        forward: Dict[int, List[EinsumSubOp]] = {}
        for op in self:
            if isinstance(op, int):
                continue
            for inp in op.inputs:
                key = inp if isinstance(inp, int) else id(inp)
                forward.setdefault(key, []).append(op)
        return forward

    def _pprint_forward(self) -> str:
        "Returns one line per operator: ``id <- name(ids of its inputs)``."
        rows = []
        for op in self:
            line = "%r <- %s(%s)" % (
                id(op),
                op.name,
                ", ".join(map(str, [id(_) for _ in op.inputs])),
            )
            rows.append(line)
        return "\n".join(rows)

    def _replace_node_sequence(
        self,
        added: Optional["EinsumSubOp"],
        deleted: List["EinsumSubOp"],
    ):
        """
        Replaces a sequence of nodes by a single node, or removes it
        if *added* is None. The method does not check
        that the graph remains consistent.

        :param added: node replacing the sequence, None to connect
            the consumers of the last deleted node directly to the
            only input of the first deleted node
        :param deleted: nodes to remove, the consumers of the last one
            are reconnected to the replacement
        """
        forward = self._get_forward_nodes()
        key = id(deleted[-1])
        assert key in forward, (
            "Key {} missing in all forward nodes (other keys {}), "
            "all keys:\n{}".format(
                key, [id(_) for _ in deleted], self._pprint_forward()
            )
        )

        # deletion
        mark_input = None
        for d in deleted:
            del self._nodes[id(d)]
            if id(d) in self._mark:
                # marks are stored in both directions (see mark),
                # the replacement inherits the marked position
                del self._mark[id(d)]
                dels = [k for k, v in self._mark.items() if v is d]
                assert len(dels) == 1, (
                    "Input %d has more than one marked operator "
                    "(%r)." % (id(d), dels)
                )
                mark_input = dels[0]
                del self._mark[mark_input]
        deleted_ids = {id(o) for o in deleted}
        rem = [i for i, op in enumerate(self._ops) if id(op) in deleted_ids]
        assert len(rem) == len(
            deleted
        ), f"Mismatched length {rem!r}, {deleted_ids!r}, len={len(deleted)!r}."
        for i in reversed(rem):
            del self._ops[i]
        # last_added_op may point to a deleted node
        self.last_added_op = None

        # insertion
        if added is not None:
            self._ops.insert(rem[0], added)
            self._nodes[id(added)] = added
            replacement = added
        else:
            inps = deleted[0].inputs
            assert len(inps) == 1, "More than one input. Call another method."
            replacement = inps[0]
        for op in forward[key]:
            op.inputs = tuple(
                replacement if id(inp) == key else inp for inp in op.inputs
            )
        if mark_input is not None:
            self.mark(mark_input, replacement)

    def remove_duplicate_transpose(self, verbose: bool = False):
        """
        Removes consecutive transpose by merging them.
        A transpose is merged with the transpose producing its input
        only if no other node consumes that input. Two transpositions
        cancelling each other are both removed.

        :param verbose: display intermediate information
        """
        while True:
            forward = self._get_forward_nodes()
            # Not efficient to take the first one and to
            # start again but the graph should not be too big.
            cand = None
            for op in self:
                if op.name != "transpose":
                    continue
                inp = op.inputs[0]
                if (
                    isinstance(inp, EinsumSubOp)
                    and inp.name == "transpose"
                    and len(forward[id(inp)]) == 1
                ):
                    cand = op
                    break
            if cand is None:
                break

            op2 = cand
            op1 = cand.inputs[0]
            perm1 = op1.kwargs["perm"]
            perm2 = op2.kwargs["perm"]
            assert len(perm1) == len(
                perm2
            ), "Transposition should have the same length " "%r, %r." % (
                perm1,
                perm2,
            )
            # composition of both permutations, assuming numpy.transpose
            # semantics for perm: axis i of op2 comes from axis perm1[perm2[i]]
            perm = [perm1[p] for p in perm2]
            if list(range(len(perm))) == perm:
                # identity, everything needs to be removed
                new_op = None
            else:
                new_op = op2.__class__(
                    op2.full_dim, op2.name, op1.inputs[0], perm=tuple(perm)
                )
            self._replace_node_sequence(new_op, [op1, op2])
            if verbose:
                print(
                    "[GraphEinsumSubOp.remove_duplicate_transpose] remove nodes %r"
                    " - id=%d,%d + %d perm1=%r perm2=%r -> perm=%r"
                    % (
                        op2.name,
                        id(op1),
                        id(op2),
                        id(new_op) if new_op is not None else -1,
                        perm1,
                        perm2,
                        perm,
                    )
                )

    def to_onnx(
        self,
        output: str,
        *inputs: Union[str, Tuple[str, Tuple[int, Any]]],
        dtype: Optional[Any] = None,
        verbose: bool = False,
        opset: Optional[int] = None,
        **kwargs: Dict[str, Any],
    ) -> "ModelProto":
        """
        Converts the graph into ONNX.

        :param output: output name
        :param inputs: input names, an input may also be a tuple
            ``(name, (elem_type, shape))`` to give its type and shape,
            otherwise it gets type *dtype* and an unknown shape
        :param dtype: type used for all operators, float32 if None
        :param verbose: display intermediate operators
        :param opset: desired opset, ``DEFAULT_OPSET`` if None
        :param kwargs: additional parameter to use when building
            the ONNX graph, list of supported parameters:
            *name*, *ir_version*, *producer_name*,
            *producer_version*, *initializer*
        :return: ONNX graph

        Not all graphs can be converted into ONNX. Only graphs produced
        with `strategy='numpy'` can be converted otherwise the following
        error shows up:

        ::

            NotImplementedError: to_onnx not implemented for 'matmul'.
        """
        from ..onnx_nodes import onnx_remove_node_unused

        if opset is None:
            opset = DEFAULT_OPSET
        if verbose:
            print(
                "[GraphEinsumSubOp.to_onnx] %r -> %s opset=%r "
                "dtype=%r" % (inputs, output, opset, dtype)
            )

        # inputs
        onx_inputs = []
        proto = guess_proto_dtype(numpy.float32 if dtype is None else dtype)
        lengths = self.metadata["lengths"]
        # input position -> name, every operator then adds its own result
        names: Dict[int, str] = {}
        for inp, le in zip(inputs, lengths):
            if isinstance(inp, tuple):
                name, (typ, shape) = inp
                assert le == len(shape), (
                    "Irreconcilable shapes for input %r: "
                    "%r != len(%r)." % (name, le, shape)
                )
                onx_inputs.append(helper.make_tensor_value_info(name, typ, shape))
            else:
                name = inp
                onx_inputs.append(
                    helper.make_tensor_value_info(name, proto, [None] * le)
                )
            names[len(names)] = name

        # output, its rank is the last value of lengths
        onx_output = helper.make_tensor_value_info(
            output, proto, [None] * lengths[-1]
        )

        # nodes
        nodes = []
        inits: List[TensorProto] = []
        if "initializer" in kwargs:
            inits.extend(kwargs["initializer"])
        for op in self:
            for onx_node in op.to_onnx(names, verbose=verbose, opset=opset):
                if hasattr(onx_node, "output"):
                    nodes.append(onx_node)
                else:
                    inits.append(onx_node)

        # the last node produces the result, Identity renames it
        last_node = nodes[-1]
        nodes.append(helper.make_node("Identity", [last_node.output[0]], [output]))

        # Builds the graph
        model = helper.make_model(
            opset_imports=[helper.make_operatorsetid("", opset)],
            ir_version=kwargs.get("ir_version", DEFAULT_IR_VERSION),
            producer_name=kwargs.get("producer_name", "onnx_extended"),
            producer_version=kwargs.get("producer_version", "0.0.dev"),
            graph=helper.make_graph(
                name=kwargs.get("name", "einsum"),
                inputs=onx_inputs,
                outputs=[onx_output],
                initializer=inits,
                nodes=nodes,
            ),
        )
        return onnx_remove_node_unused(model)