from typing import Dict, List, Optional, Tuple, Union
import numpy
def numpy_diagonal(m: numpy.ndarray, axis: int, axes: Tuple[int, ...]) -> numpy.ndarray:
    """
    Extracts diagonal coefficients from an array.

    All the axes listed in *axes* are walked with the same index,
    only *axis* remains in the result, the other diagonal axes are removed.

    :param m: input array
    :param axis: kept axis among the diagonal ones
    :param axes: diagonal axes (axis must be one of them)
    :return: output
    :raises RuntimeError: if *axis* is not in *axes*

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import numpy_diagonal

        mat = numpy.arange(8).reshape((2, 2, 2))
        print(mat)
        diag = numpy_diagonal(mat, 1, [1, 2])
        print(diag)
    """
    if axis not in axes:
        raise RuntimeError(f"axis {axis!r} must be in axes {axes!r}.")

    # The intermediate array keeps every dimension (size 1 for the diagonal
    # axes other than *axis*) so that the final reshape only drops them.
    dropped = [dim in axes and dim != axis for dim in range(len(m.shape))]
    keep_shape = [1 if drop else size for drop, size in zip(dropped, m.shape)]
    final_shape = [size for drop, size in zip(dropped, m.shape) if not drop]

    output = numpy.empty(tuple(keep_shape), dtype=m.dtype)
    source_index: List[Union[int, slice]] = [slice(size) for size in m.shape]
    target_index: List[Union[int, slice]] = [slice(size) for size in m.shape]
    for k in range(keep_shape[axis]):
        # Same position k on every diagonal axis of the input; the removed
        # axes of the output only have position 0.
        for a in axes:
            source_index[a] = k
            target_index[a] = k if a == axis else 0
        output[tuple(target_index)] = m[tuple(source_index)]

    return output.reshape(tuple(final_shape))
def _numpy_extended_dot_equation(
m1_dim: int,
m2_dim: int,
axes: Tuple[int, ...],
left: Tuple[int, ...],
right: Tuple[int, ...],
) -> str:
"""
Returns the equation equivalent to an extended version
of an aligned matrix multiplication
(see :func:`numpy_extended_dot
<onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot>`).
:param m1: number of dimensions of the first matrix
:param m2: number of dimensions of the second matrix
:param axes: summation axes
:param axes: summation axes
:param left: left axes
:param right: right axes
:return: equation
.. runpython::
:showcode:
import numpy
from onnx_extended.tools.einsum.einsum_impl_ext import (
numpy_extended_dot_python, _numpy_extended_dot_equation)
a = numpy.arange(6).reshape((3, 2, 1))
b = numpy.arange(12).reshape((3, 1, 4))
print(numpy_extended_dot_python(
a, b, axes=(0, ), left=(1,), right=(2,)))
# Equivalent einsum equation
print('equation', _numpy_extended_dot_equation(
len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))
# Same einsum computation written in a different way.
print(numpy.einsum('kix,kxj->xij', a, b))
"""
assert m1_dim == m2_dim, (
"Matrices m1 and m2 must have the same number of dimensions, "
"m1=%r, m2=%r." % (m1_dim, m2_dim)
)
total = set(axes) | set(left) | set(right)
assert len(total) <= m1_dim, (
"Whole set of involved axes should be inferior to the number "
"of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
"." % (total, axes, left, right, m1_dim)
)
def _check_(axs, n):
for a in axs:
if a < 0 or a >= n:
raise ValueError(
"One axis %d (in %r) is negative or above the maximum "
"dimension %d." % (a, axs, n)
)
_check_(axes, m1_dim)
_check_(left, m1_dim)
_check_(right, m1_dim)
l1 = [chr(i + 97) for i in range(m1_dim)]
l2 = [chr(i + 97) for i in range(m1_dim)]
l3: List[Optional[str]] = [chr(i + 97) for i in range(m1_dim)]
for a in left:
l1[a] = l1[a].upper()
l3[a] = l3[a].upper()
for a in right:
l2[a] = l2[a].upper()
l3[a] = l3[a].upper()
for a in axes:
l1[a] = l1[a].lower()
l2[a] = l2[a].lower()
if a not in right:
l3[a] = None
else:
l3[a] = l3[a].lower()
eq = f"{''.join(l1)},{''.join(l2)}->{''.join(s for s in l3 if s)}"
return eq
def _common_check_numpy_extended_dot(
m1: numpy.ndarray,
m2: numpy.ndarray,
axes: Tuple[int, ...],
left: Tuple[int, ...],
right: Tuple[int, ...],
):
"""
Common verifications for all implementations of
:func:`numpy_extended_dot
<onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot>`.
"""
assert (
m1.dtype == m2.dtype
), f"Both matrices should share the same dtype {m1.dtype!r} != {m2.dtype!r}."
m1_dim = len(m1.shape)
m2_dim = len(m2.shape)
assert m1_dim == m2_dim, (
"Matrices m1 and m2 must have the same number of dimensions, "
"m1=%r, m2=%r." % (m1_dim, m2_dim)
)
total = set(axes) | set(left) | set(right)
assert len(total) <= m1_dim, (
"Whole set of involved axes should be inferior to the number "
"of dimensions: %r = {%r} | {%r} | {%r} has more than %d elements"
"." % (total, axes, left, right, m1_dim)
)
def numpy_extended_dot(
    m1: numpy.ndarray,
    m2: numpy.ndarray,
    axes: Tuple[int, ...],
    left: Tuple[int, ...],
    right: Tuple[int, ...],
    verbose: bool = False,
) -> numpy.ndarray:
    """
    Extended version of a matrix multiplication (:func:`numpy.dot`)
    with two matrices *m1*, *m2* of the same dimensions.
    Loops over *left* axes for *m1* and *right* axes for *m2*,
    summation is done over *axes*.
    Other axes must be empty.
    This multiplication combines matrix multiplication (dot)
    and broadcasted multiplication term by term.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, it has the same number of dimensions as the inputs,
        summation axes not in *right* are kept with a dimension of 1

    The dot product is equivalent to:

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2)

        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[1], left=[0], right=[2],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    Empty axes should be squeezed to get identical results.
    Dot product when the second matrix is transposed.

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        print("dot product")
        print(m1 @ m2.T)

        dm1 = m1.reshape((2, 1, 2))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1],
                                 verbose=True)
        print("extended dot product")
        print(dot)

    An example when right axes include the summation axis.

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import numpy_extended_dot

        m1 = numpy.arange(4).reshape((2, 2))
        m2 = m1 + 10
        dm1 = m1.reshape((2, 2, 1))
        dm2 = m2.reshape((1, 2, 2))
        dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1, 2],
                                 verbose=True)
        print(dot)

    Example in higher dimension:

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import numpy_extended_dot

        m1 = numpy.arange(8).reshape((2, 2, 2))
        m2 = m1 + 10

        dot = numpy_extended_dot(m1, m2, [1], [0], [2], verbose=True)
        print(dot)

    The current implementation still uses :func:`numpy.einsum`
    but this should be replaced.
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    eq = _numpy_extended_dot_equation(len(m1.shape), len(m2.shape), axes, left, right)
    if verbose:
        print(f" [numpy_extended_dot] {eq}: {m1.shape!r} @ {m2.shape!r}")
    output = numpy.einsum(eq, m1, m2)

    # einsum removed the summation axes which are not in *right*,
    # they are restored as dimensions of size 1. Insertions must happen
    # in increasing order (and once per axis) so that every insertion
    # index is already its final position in the output shape.
    new_shape = list(output.shape)
    for a in sorted(set(axes)):
        if a not in right:
            new_shape.insert(a, 1)
    if verbose:
        print(f" [numpy_extended_dot] {output.shape!r} reshaped into {new_shape!r} ")
    return output.reshape(tuple(new_shape))
def numpy_extended_dot_ouput_shape(
    m1: numpy.ndarray,
    m2: numpy.ndarray,
    axes: Tuple[int, ...],
    left: Tuple[int, ...],
    right: Tuple[int, ...],
) -> numpy.ndarray:
    """
    Computes the output shape of results produced by function
    :func:`numpy_extended_dot
    <onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot>` or
    :func:`numpy_extended_dot_python
    <onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_python>`.

    Every dimension is 1 except the *left* axes, which take the dimension
    of *m1*, and the *right* axes, which take the dimension of *m2*.
    An axis in both *left* and *right* follows numpy broadcasting:
    a dimension equal to 1 is broadcast to the other one.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :return: output shape as an int64 array
    :raises RuntimeError: if an axis in both *left* and *right* has
        two different dimensions, none of them being 1
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)
    new_shape = numpy.full(m1_dim, 1, dtype=numpy.int64)
    for i in left:
        new_shape[i] = m1.shape[i]
    for i in right:
        if i not in left:
            new_shape[i] = m2.shape[i]
            continue
        if (
            m1.shape[i] != m2.shape[i]
            and m1.shape[i] != 1
            and m2.shape[i] != 1
        ):
            raise RuntimeError(
                "Matrices should have the same dimension for dimension %d, "
                "shapes=%r @ %r." % (i, m1.shape, m2.shape)
            )
        # Broadcasting: a dimension of 1 takes the size of the other side
        # (m2 used to win unconditionally, giving 1 when m2 was the
        # broadcast side).
        new_shape[i] = m1.shape[i] if m2.shape[i] == 1 else m2.shape[i]
    return new_shape
def _numpy_extended_dot_python_l1l2l3(
m1_dim: int, axes: Tuple[int, ...], left: Tuple[int, ...], right: Tuple[int, ...]
) -> Tuple[List[str], List[str], List[Optional[str]]]:
l1 = [chr(i + 97) for i in range(m1_dim)]
l2 = [chr(i + 97) for i in range(m1_dim)]
l3: List[Optional[str]] = [chr(i + 97) for i in range(m1_dim)]
for a in left:
l1[a] = l1[a].upper()
l3[a] = l3[a].upper()
for a in right:
l2[a] = l2[a].upper()
l3[a] = l3[a].upper()
for a in axes:
l1[a] = l1[a].lower()
l2[a] = l2[a].lower()
if a not in right:
l3[a] = "-"
else:
l3[a] = l3[a].lower()
return l1, l2, l3
def _numpy_extended_dot_python_intermediate(
m1_shape: Tuple[int, ...],
m2_shape: Tuple[int, ...],
l1: List[str],
l2: List[str],
l3: List[Optional[str]],
) -> Tuple[
List[str], numpy.ndarray, Dict[str, int], List[bool], List[bool], numpy.ndarray
]:
names = list(sorted(set(l1 + l2)))
kind = numpy.zeros(len(names), dtype=numpy.int64)
cols = {}
for i, n in enumerate(names):
if n in l1:
kind[i] += 1
cols[n] = l1.index(n)
if n in l2:
kind[i] += 2
cols[n] = l2.index(n)
if n in l3:
kind[i] += 4
pos = numpy.zeros(len(names), dtype=numpy.int64)
for j in range(0, pos.shape[0]):
pos[j] = cols[names[j]]
common = [(kind[i] & 3) == 3 for i in range(len(kind))]
broadcast = [
common[i] and m1_shape[pos[i]] != m2_shape[pos[i]] for i in range(len(common))
]
return names, kind, cols, common, broadcast, pos
def _numpy_extended_dot_python_update_broadcast(
m1: numpy.ndarray,
m2: numpy.ndarray,
axes: Tuple[int, ...],
left: Tuple[int, ...],
right: Tuple[int, ...],
l1: List[str],
l2: List[str],
l3: List[Optional[str]],
names: List[str],
broadcast: List[bool],
cols: Dict[str, int],
kind: numpy.ndarray,
common: List[bool],
verbose: bool = False,
) -> Tuple[List[str], List[str], List[Optional[str]]]:
def dispb(c):
return "".join("o" if b else "." for b in c)
if verbose:
print(
"[GENERICDOT] before broadcast %s,%s->%s or %s"
% (
"".join(l1),
"".join(l2),
"".join([(s if s else "?") for s in l3]),
_numpy_extended_dot_equation(
len(m1.shape), len(m1.shape), axes, left, right
),
)
)
print(
"[GENERICDOT] names=%s kind=%r common=%s broadcast=%s"
% ("".join(names), kind.tolist(), dispb(common), dispb(broadcast))
)
for i in range(len(broadcast)):
if broadcast[i] and not (kind[i] & 3) == 3:
raise RuntimeError(
"Broadcast should only happen on common axes, "
"axes=%r left=%r right=%r shape1=%r shape2=%r."
"" % (axes, left, right, m1.shape, m2.shape)
)
if not broadcast[i]:
continue
# We split letters.
p = cols[names[i]]
dim = (m1.shape[p], m2.shape[p])
let = [l1[p], l2[p], l3[p]]
inp = 1 if dim[0] == 1 else 0
if verbose:
print(
"[GENERICDOT] name=%s dim=%r let=%r inp=%r p=%r"
% (names[i], dim, let, inp, p)
)
print(f" B0 l1={l1!r}, l2={l2!r} l3={l3!r}")
if (kind[i] & 4) > 0:
# Summation axis is part of the output.
assert (
let[inp] is not None
), f"Unexpected value for let[{inp}] in let={let}."
if let[inp].lower() == let[inp]:
let[inp] = let[inp].upper()
else:
let[inp] = let[inp].lower()
l3[p] = let[inp]
if inp == 1:
l2[p] = let[inp]
else:
l1[p] = let[inp]
if verbose:
print(f" B1 l1={l1!r}, l2={l2!r} l3={l3!r}")
else:
# Summation axis is not part of the output.
assert (
let[inp] is not None
), f"Unexpected value for let[{inp}] in let={let}."
if let[inp].lower() == let[inp]:
let[inp] = let[inp].upper()
else:
let[inp] = let[inp].lower()
if inp == 1:
l2[p] = let[inp]
else:
l1[p] = let[inp]
if verbose:
print(f" B2 l1={l1!r}, l2={l2!r} l3={l3!r}")
return l1, l2, l3
def numpy_extended_dot_python(
    m1: numpy.ndarray,
    m2: numpy.ndarray,
    axes: Tuple[int, ...],
    left: Tuple[int, ...],
    right: Tuple[int, ...],
    verbose: bool = False,
) -> numpy.ndarray:
    """
    Implementation of :func:`numpy_extended_dot
    <onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot>`
    in pure python. This implementation is not efficient
    but shows how to implement this operation without :func:`numpy.einsum`.

    Every distinct letter (see :func:`_numpy_extended_dot_python_l1l2l3`)
    becomes one loop index; the function walks every combination of indices
    and accumulates ``m1[...] * m2[...]`` into the output.
    Inputs with a zero-length dimension produce an output filled with zeros,
    0-d inputs produce their product.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output, its shape is given by
        :func:`numpy_extended_dot_ouput_shape`

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import numpy_extended_dot_python
        from onnx_extended.tools.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_python(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """

    def dispb(c):
        return "".join("o" if b else "." for b in c)

    new_shape = numpy_extended_dot_ouput_shape(m1, m2, axes, left, right)
    m1_dim = len(m1.shape)

    # output result
    res = numpy.full(tuple(new_shape), 0, dtype=m1.dtype)

    # indices
    l1, l2, l3 = _numpy_extended_dot_python_l1l2l3(m1_dim, axes, left, right)
    names, kind, cols, common, broadcast, pos = _numpy_extended_dot_python_intermediate(
        m1.shape, m2.shape, l1, l2, l3
    )

    if any(broadcast):
        # Broadcast dimensions are given distinct letters, the bookkeeping
        # must be computed again.
        l1, l2, l3 = _numpy_extended_dot_python_update_broadcast(
            m1,
            m2,
            axes,
            left,
            right,
            l1,
            l2,
            l3,
            names,
            broadcast,
            cols,
            kind,
            common,
            verbose=verbose,
        )
        (
            names,
            kind,
            cols,
            common,
            broadcast,
            pos,
        ) = _numpy_extended_dot_python_intermediate(m1.shape, m2.shape, l1, l2, l3)

    indices = numpy.array([0 for n in names], dtype=numpy.int64)
    # Positions, in *names*, of the letters of both inputs.
    pl1 = numpy.array([names.index(c) for c in l1], dtype=numpy.int64)
    pl2 = numpy.array([names.index(c) for c in l2], dtype=numpy.int64)
    # Number of values taken by every loop index.
    limits = numpy.array(
        [
            m1.shape[pos[n]] if (kind[n] & 1) == 1 else m2.shape[pos[n]]
            for n in range(len(names))
        ],
        dtype=numpy.int64,
    )
    # -1 marks an output dimension not driven by any index (always 0).
    plo = numpy.array(
        [-1 if c not in names else names.index(c) for c in l3], dtype=numpy.int64
    )

    if verbose:
        print(
            "[GENERICDOT] %s,%s->%s or %s"
            % (
                "".join(l1),
                "".join(l2),
                "".join([(s if s else "?") for s in l3]),
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right
                ),
            )
        )
        print(
            "[GENERICDOT] shape1=%r shape2=%r shape=%r"
            % (m1.shape, m2.shape, res.shape)
        )
        print(f"[GENERICDOT] axes={axes!r} left={left!r} right={right!r}")
        print(f"[GENERICDOT] pl1={pl1!r} pl2={pl2!r} plo={plo!r}")
        print(
            "[GENERICDOT] names=%s kind=%r common=%s broadcast=%s"
            % ("".join(names), kind.tolist(), dispb(common), dispb(broadcast))
        )
        print(f"[GENERICDOT] pos={pos.tolist()!r}")
        print(f"[GENERICDOT] cols={cols!r}")
        print(f"[GENERICDOT] limits={limits!r}")

    if len(names) == 0:
        # 0-d inputs: there is no index to loop over, a single product.
        res[()] += m1[()] * m2[()]
        return res
    if (limits == 0).any():
        # A zero-length dimension: nothing to accumulate, the loop below
        # would otherwise read out of bounds before checking the limits.
        return res

    while indices[0] < limits[0]:
        # The function spends most of its time in these lines.
        t1 = tuple(indices[n] for n in pl1)
        t2 = tuple(indices[n] for n in pl2)
        to = tuple(0 if n == -1 else indices[n] for n in plo)
        c = m1[t1] * m2[t2]

        if verbose:
            print(f" {t1!r} x {t2!r} -> {to!r} v={c!r} I={indices!r}")

        res[to] += c

        # Odometer increment: the last index moves fastest, a carry
        # propagates towards the first index.
        last = len(indices) - 1
        indices[last] += 1
        for i in range(last, 0, -1):
            if indices[i] < limits[i]:
                break
            indices[i] = 0
            if i > 0:
                indices[i - 1] += 1

    return res
def numpy_extended_dot_matrix(
    m1: numpy.ndarray,
    m2: numpy.ndarray,
    axes: Tuple[int, ...],
    left: Tuple[int, ...],
    right: Tuple[int, ...],
    verbose: bool = False,
) -> numpy.ndarray:
    """
    Implementation of :func:`numpy_extended_dot
    <onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot>`
    using dot product, multiplication, transpose and reduction
    but not a custom python implementation like
    :func:`numpy_extended_dot_python
    <onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_python>`.

    Three cases are handled:

    * no summation axis and no axis shared by *left* and *right*:
      elementwise multiplication,
    * summation axes in neither *left* nor *right*: reductions,
      transpositions and a batched matrix multiplication,
    * summation axes in *right* but not in *left*: those axes are moved
      to *left* and the function is called again.

    :param m1: first matrix
    :param m2: second matrix
    :param axes: summation axes
    :param left: left axes
    :param right: right axes
    :param verbose: display intermediate information
    :return: output
    :raises RuntimeError: if the summation axes do not have the same
        length in both matrices or if the combination of axes is not
        one of the supported cases

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import numpy_extended_dot_matrix
        from onnx_extended.tools.einsum.einsum_impl_ext import (
            _numpy_extended_dot_equation)

        a = numpy.arange(6).reshape((3, 2, 1))
        b = numpy.arange(12).reshape((3, 1, 4))

        print(numpy_extended_dot_matrix(
            a, b, axes=(0, ), left=(1,), right=(2,)))

        # Equivalent einsum equation
        print('equation', _numpy_extended_dot_equation(
            len(a.shape), len(a.shape), axes=(0, ), left=(1,), right=(2,)))

        # Same einsum computation written in a different way.
        print(numpy.einsum('kix,kxj->xij', a, b))
    """
    _common_check_numpy_extended_dot(m1, m2, axes, left, right)

    if verbose:
        print(
            "[GENERICDOT] shape1=%r shape2=%r axes=%r "
            "left=%r right=%r -- %s"
            % (
                m1.shape,
                m2.shape,
                axes,
                left,
                right,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right
                ),
            )
        )

    if len(axes) == 0 and not (set(left) & set(right)):
        # Simple multiplication
        res = m1 * m2
        if verbose:
            print(f"[GENERICDOT] Mul {m1.shape!r} @ {m2.shape!r} -> {res.shape!r}")
        return res

    if len(set(axes) & set(left)) == 0 and len(set(axes) & set(right)) == 0:
        # The summation axes intersect neither left nor right:
        # reduction, transposition and a batched matrix multiplication.

        # ReduceSum: m1 is summed over the axes only in right,
        # m2 over the axes only in left (dimensions are kept).
        right_no_left = set(right) - (set(right) & (set(left) | set(axes)))
        if right_no_left:
            red1 = m1.sum(axis=tuple(sorted(right_no_left)), keepdims=True)
            if verbose:
                print(
                    "[GENERICDOT] reducesumL=%r, %r -> %r"
                    % (right_no_left, m1.shape, red1.shape)
                )
        else:
            red1 = m1

        left_no_right = set(left) - (set(left) & (set(right) | set(axes)))
        if left_no_right:
            red2 = m2.sum(axis=tuple(sorted(left_no_right)), keepdims=True)
            if verbose:
                print(
                    "[GENERICDOT] reducesumR=%r, %r -> %r"
                    % (left_no_right, m2.shape, red2.shape)
                )
        else:
            red2 = m2

        # Transpose: axes common to left and right first, then the
        # remaining axes, the summation axes last.
        common_axes = sorted(set(left) & set(right))
        i_axes = [
            (-1 if i in common_axes else (1 if i in axes else 0), i)
            for i in range(len(m1.shape))
        ]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        trm1 = numpy.transpose(red1, axes=perm)
        trm2 = numpy.transpose(red2, axes=perm)
        if verbose:
            print(f"[GENERICDOT] transposeL={perm!r}, {red1.shape!r} -> {trm1.shape!r}")
            print(f"[GENERICDOT] transposeR={perm!r}, {red2.shape!r} -> {trm2.shape!r}")
        final_shape = numpy_extended_dot_ouput_shape(m1, m2, axes, left, right)
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        perm_common_axes = [i for i in range(len(perm)) if perm[i] in common_axes]

        if verbose:
            print(
                "[GENERICDOT] MatMul %r @ %r -> %r -- %s"
                % (
                    m1.shape,
                    m2.shape,
                    final_shape,
                    _numpy_extended_dot_equation(
                        len(m1.shape), len(m1.shape), axes, left, right
                    ),
                )
            )
            print(f"[GENERICDOT] axes={axes!r} left={left!r} right={right!r}")
            print(
                "[GENERICDOT] perm=%r perm_left=%r "
                "perm_right=%r perm_common_axes=%r"
                % (perm, perm_left, perm_right, perm_common_axes)
            )

        # Reshape both matrices into 3D: (common axes, other axes,
        # summation axes). The summation axes are the last ones after
        # the transposition.
        dim0 = int(numpy.prod([trm1.shape[i] for i in perm_common_axes]))
        dim0b = int(numpy.prod([trm2.shape[i] for i in perm_common_axes]))
        if len(axes) > 0:
            all_axes = list(range(0, len(m1.shape)))
            new_axes = all_axes[-len(axes) :]
        else:
            new_axes = []
        dim1 = int(numpy.prod([trm1.shape[i] for i in new_axes]))
        dim2 = int(numpy.prod([trm2.shape[i] for i in new_axes]))
        if dim1 != dim2:
            raise RuntimeError(
                "Summation axis do not have the same length %d != %d, "
                "trshape1=%r trshape2=%r "
                "p_axes=%r p_left=%r p_right=%r p_common=%r"
                "."
                % (
                    dim1,
                    dim2,
                    trm1.shape,
                    trm2.shape,
                    new_axes,
                    perm_left,
                    perm_right,
                    perm_common_axes,
                )
            )

        shm1 = trm1.reshape((dim0, -1, dim1))
        shm2 = trm2.reshape((dim0b, -1, dim2))

        if verbose:
            # The second operand is reshaped with dim0b, not dim0.
            print(
                "[GENERICDOT] Reshape %r @ %r -> %r @ %r"
                % ((dim0, -1, dim1), (dim0b, -1, dim2), shm1.shape, shm2.shape)
            )
            print("[GENERICDOT] matmul")

        # Multiplication (this should be done in a different way).
        res = shm1 @ numpy.transpose(shm2, axes=(0, 2, 1))

        if verbose:
            print(f"[GENERICDOT] Shape after multiplication {res.shape}")

        # Transpose again
        not_in_both = []
        for i in range(0, len(m1.shape)):
            if i not in left and i not in right:
                not_in_both.append(i)
        ordered_axes = (
            common_axes
            + list(i for i in left if i not in right)
            + list(i for i in right if i not in left)
            + not_in_both
        )

        perm_not_in_both = [i for i in range(len(perm)) if perm[i] in not_in_both]
        current_shape = (
            [max(trm1.shape[i], trm2.shape[i]) for i in sorted(perm_common_axes)]
            + [trm1.shape[i] for i in sorted(perm_left) if i not in perm_common_axes]
            + [trm2.shape[i] for i in sorted(perm_right) if i not in perm_common_axes]
            + [1 for i in perm_not_in_both]
        )

        if verbose:
            print(
                "[GENERICDOT] current_shape=%r final_shape=%r "
                "last_shape=%r" % (current_shape, final_shape, res.shape)
            )

        assert len(current_shape) == len(final_shape), (
            "Shapes mismatch %r > %r, "
            "shape1=%r shape2=%r axes=%r left=%r right=%r."
            % (current_shape, final_shape, m1.shape, m2.shape, axes, left, right)
        )

        res = res.reshape(current_shape)

        # Inverse permutation of ordered_axes.
        perm = [(a, i) for i, a in enumerate(ordered_axes)]
        perm.sort()
        perm = [p[1] for p in perm]

        if verbose:
            print(f"[GENERICDOT] ordered_axes={ordered_axes!r} perm={perm!r}")

        return numpy.transpose(res, axes=perm)

    else:
        # Multiplication and Matrix multiplication at the same time.
        l_axes = set(left) & set(axes)
        r_axes = set(right) & set(axes)
        if r_axes and not l_axes:
            # Summation axes also in right are moved into left,
            # the function is then called again on the simpler problem.
            new_axes = list(a for a in axes if a not in right)
            new_left = list(sorted(set(left) | r_axes))
            if verbose:
                eq1 = _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right
                )
                eq2 = _numpy_extended_dot_equation(
                    len(m1.shape),
                    len(m1.shape),
                    tuple(new_axes),
                    tuple(new_left),
                    tuple(right),
                )
                print(
                    "[GENERICDOT] replace left %r by %r axes %r by %r, "
                    "eq %r by %r" % (left, new_left, axes, new_axes, eq1, eq2)
                )
            return numpy_extended_dot_matrix(
                m1, m2, tuple(new_axes), tuple(new_left), tuple(right), verbose=verbose
            )
        raise RuntimeError(
            "shape1=%r shape2=%r axes=%r left=%r right=%r eq=%s."
            % (
                m1.shape,
                m2.shape,
                axes,
                left,
                right,
                _numpy_extended_dot_equation(
                    len(m1.shape), len(m1.shape), axes, left, right
                ),
            )
        )