yobx.helpers._einsum.einsum_impl

yobx.helpers._einsum.einsum_impl.analyse_einsum_equation(equation: str) -> Tuple[str, ndarray, List[int], List[Dict[str, List[int]] | None]]

Analyses an einsum equation.

Parameters:

equation – an equation in numpy.einsum() format

Returns:

a tuple of four elements: the letters used in the equation, a matrix (see below), the length of each component, and, for each component, its duplicated letters (or None)

The returned matrix is defined as follows:

m_{ij}=\left\{\begin{array}{ll}-1 &
\text{if letter } j \text{ is not involved in input } i \\
p & \text{if } p \text{ is the position of letter } j \text{ in input } i
\end{array}\right.
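The analysis can be sketched in a few lines of pure numpy. This is a hypothetical re-implementation (analyse_equation is not part of yobx) written under two assumptions: the matrix has one extra row for the output term after the input rows, and each duplicates entry maps every letter of a term to its positions when some letter repeats, None otherwise. Its first three rows match the vectors printed next to each input in the graph shown for decompose_einsum_equation().

```python
import numpy as np


def analyse_equation(equation):
    """Hypothetical sketch of the einsum analysis (not the yobx implementation)."""
    lhs, output = equation.replace(" ", "").split("->")
    # Assumption: the output term is stored after the inputs.
    terms = lhs.split(",") + [output]
    letters = "".join(sorted(set("".join(terms))))
    # One row per term, one column per letter; -1 means "letter not in term".
    mat = np.full((len(terms), len(letters)), -1, dtype=np.int64)
    for i, term in enumerate(terms):
        for pos, c in enumerate(term):
            # A repeated letter keeps its last position in this sketch.
            mat[i, letters.index(c)] = pos
    lengths = [len(term) for term in terms]
    duplicates = []
    for term in terms:
        positions = {}
        for pos, c in enumerate(term):
            positions.setdefault(c, []).append(pos)
        # Assumption: None when no letter is repeated in the term.
        repeated = any(len(p) > 1 for p in positions.values())
        duplicates.append(positions if repeated else None)
    return letters, mat, lengths, duplicates


letters, mat, lengths, duplicates = analyse_equation("bac,cd,def->ebc")
print(letters)  # abcdef
print(mat)
```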

yobx.helpers._einsum.einsum_impl.apply_einsum_sequence(seq: GraphEinsumSubOp, *inputs: ndarray, verbose: bool = False, **kwargs: Dict[str, Any]) -> ndarray

Applies a sequence of operations to a list of inputs. The sequence is produced by decompose_einsum_equation().

Parameters:
  • seq – sequence of operations

  • inputs – inputs

  • verbose – verbosity

  • kwargs – additional parameters, see apply_sequence in GraphEinsumSubOp

Returns:

output

<<<

import numpy
from yobx.helpers._einsum import decompose_einsum_equation
from yobx.helpers._einsum.einsum_impl import apply_einsum_sequence

m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
m2 = numpy.arange(4).reshape((2, 2)) + 100
m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000

seq = decompose_einsum_equation("bac,cd,def->ebc")
res = apply_einsum_sequence(seq, m1, m2, m3)
print(res)

>>>

    [[[ 8866198  9864696]
      [12090270 13152928]]
    
     [[ 8883886  9884376]
      [12114390 13179168]]]
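For this particular equation the result can also be checked by hand with plain numpy, without yobx. The sketch below is not what apply_einsum_sequence() does internally: it sums out a and f first (each appears in a single input and not in the output), contracts d with a matrix product, and broadcasts over c, which is shared by both factors and the output. It prints the same array as above.

```python
import numpy as np

m1 = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
m2 = np.arange(4).reshape((2, 2)) + 100
m3 = np.arange(8).reshape((2, 2, 2)) + 1000

# "a" appears only in m1 and "f" only in m3: sum them out first.
s1 = m1.sum(axis=1)  # bac -> bc
s3 = m3.sum(axis=2)  # def -> de
# "d" is shared by m2 and m3 and absent from the output: contract it.
t = m2 @ s3  # cd,de -> ce
# "c" is kept in the output: broadcast the two factors instead of summing.
res = s1[np.newaxis, :, :] * t.T[:, np.newaxis, :]  # -> ebc
print(res)
```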
yobx.helpers._einsum.einsum_impl.decompose_einsum_equation(equation: str, *shapes: Tuple[int, ...], strategy: str = 'simple', clean: bool = False, verbose: bool = False) -> GraphEinsumSubOp

Decomposes an equation used by numpy.einsum(), knowing the input shapes, into a sequence of operations that computes the result.

Parameters:
  • equation – a string

  • shapes – sequence of input shapes

  • strategy – there are different ways to decompose the equation; this parameter selects one (see below)

  • clean – removes unnecessary nodes from the graph

  • verbose – verbosity

Returns:

an instance of GraphEinsumSubOp

About strategy:

  • ‘simple’: aligns all dimensions in alphabetical order; some generic matrix multiplications are still implemented with numpy.einsum(), but only between two matrices aligned on the same dimensions (see numpy_extended_dot)

  • ‘numpy’: same as ‘simple’, but the decomposition no longer uses numpy.einsum(); it only uses multiplications or matrix multiplications, merged into a single operator called batch_dot (see numpy_extended_dot_matrix)

Available operations: expand_dims, transpose, matmul, reduce_sum, id, squeeze, diagonal. The function analyses the equation and produces a graph whose nodes are instances of class EinsumSubOp.
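To illustrate what aligning all dimensions in alphabetical order means, here is a deliberately naive sketch: every input is transposed to alphabetical order and given a size-1 axis (expand_dims) for each letter it lacks, so all inputs broadcast against each other; the product is then reduced (reduce_sum) over the letters missing from the output and transposed to the output order. The helpers align_to_letters and naive_einsum are hypothetical, ignore repeated letters within a term (the diagonal case), and materialise the full broadcast product, which is precisely what the matmul-based decomposition avoids.

```python
import numpy as np


def align_to_letters(arr, term, letters):
    """Gives arr one axis per letter in `letters` (size 1 if the letter is absent)."""
    # Reorder the existing axes alphabetically.
    order = sorted(range(len(term)), key=lambda k: letters.index(term[k]))
    arr = np.transpose(arr, order)
    present = [term[k] for k in order]
    # Insert a size-1 axis wherever a letter is absent from this term.
    shape = [arr.shape[present.index(c)] if c in present else 1 for c in letters]
    return arr.reshape(shape)


def naive_einsum(equation, *inputs):
    """Hypothetical broadcasting einsum; assumes no letter repeats inside a term."""
    lhs, out = equation.replace(" ", "").split("->")
    terms = lhs.split(",")
    letters = "".join(sorted(set("".join(terms))))
    product = 1
    for term, arr in zip(terms, inputs):
        product = product * align_to_letters(arr, term, letters)
    # Sum over every letter that does not appear in the output.
    summed = tuple(i for i, c in enumerate(letters) if c not in out)
    reduced = product.sum(axis=summed) if summed else product
    # Remaining axes are in alphabetical order; move them to the output order.
    kept = [c for c in letters if c in out]
    return np.transpose(reduced, [kept.index(c) for c in out])


rng = np.random.default_rng(0)
a = rng.integers(0, 10, size=(2, 3, 4))  # bac
b = rng.integers(0, 10, size=(4, 5))  # cd
c = rng.integers(0, 10, size=(5, 2, 3))  # def
expected = np.einsum("bac,cd,def->ebc", a, b, c)
print(np.array_equal(naive_einsum("bac,cd,def->ebc", a, b, c), expected))  # True
```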

<<<

from yobx.helpers._einsum import decompose_einsum_equation

seq = decompose_einsum_equation("bac,cd,def->ebc")
for op in seq:
    print(op)

>>>

    EinsumSubOp('id', 0, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5)))
    EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5)))
    EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,))
    EinsumSubOp('id', 1, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5)))
    EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6)
    EinsumSubOp('id', 2, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2)))
    EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,))
    EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6)
    EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(np.int64(0), np.int64(4), np.int64(1), np.int64(3), np.int64(2), np.int64(5)))
    EinsumSubOp('squeeze', EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(np.int64(0), np.int64(4), np.int64(1), np.int64(3), np.int64(2), np.int64(5))), axes=(0, 3, 5))

The same sequence is easier to read as a graph:

[Figure: operation graph for "bac,cd,def->ebc". Input 0 (bac, [ 1 0 2 -1 -1 -1]) -> id -> expand_dims -> transpose -> reduce_sum (I0, axes=(0,)); input 1 (cd, [-1 -1 0 1 -1 -1]) -> id -> expand_dims; both feed matmul (I1, left=(1, 2), right=(2, 3)); input 2 (def, [-1 -1 -1 0 1 2]) -> id -> expand_dims -> reduce_sum (axes=(5,)); matmul I1 and this branch feed matmul (I2, axes=(3,), left=(1, 2), right=(4,)) -> transpose -> squeeze (I-1, axes=(0, 3, 5)), the output. Nodes I0, I1, I2 and I-1 are highlighted in red.]

yobx.helpers._einsum.einsum_impl.is_transpose_identity(perm: Tuple[int, ...] | List[int] | ndarray) -> bool

Tells whether the permutation perm is the identity, i.e. leaves every axis in place.

Parameters:

perm – permutation

Returns:

boolean
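A minimal equivalent check, assuming perm is a sequence of integers (tuple, list or 1-D ndarray); is_identity_permutation is a hypothetical name, not the library's code. A transpose with such a permutation can be skipped, as with the identity-like steps in the decomposition above.

```python
import numpy as np


def is_identity_permutation(perm):
    """Hypothetical equivalent: True if axis i stays at position i for every i."""
    return all(int(p) == i for i, p in enumerate(perm))


print(is_identity_permutation((0, 1, 2)))  # True
print(is_identity_permutation(np.arange(4)))  # True
print(is_identity_permutation((1, 0, 2, 3, 4, 5)))  # False
```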