onnx_extended.tools.einsum

Decomposition of Einsum into simple operations.

analyse_einsum_equation

onnx_extended.tools.einsum.einsum_impl.analyse_einsum_equation(equation: str) -> Tuple[str, ndarray, List[int], List[Dict[str, List[int]] | None]]

Analyses an einsum equation.

Parameters:

equation – numpy.einsum() equation

Returns:

four results: the list of letters, a matrix (see below), the length of each component, and the duplicates

The returned matrix is defined as follows:

m_{ij}=\left\{\begin{array}{ll}
-1 & \text{if letter } j \text{ is not involved in input } i \\
p & \text{if letter } j \text{ is at position } p \text{ in input } i
\end{array}\right.
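To make the definition concrete, the sketch below rebuilds such a matrix with numpy for the equation used throughout this page. It is a hypothetical helper (`letter_position_matrix` is not part of onnx_extended) and it also appends one row for the output term, which is an assumption about the layout of the library's result. Its first three rows match the `[ ... ]` vectors shown for inputs 0, 1 and 2 in the graph of decompose_einsum_equation().

```python
import numpy as np


def letter_position_matrix(equation):
    """Hypothetical helper, not the library's analyse_einsum_equation:
    rebuilds the matrix m described above for an equation with an output."""
    left, output = equation.split("->")
    terms = left.split(",") + [output]  # assumption: last row describes the output
    letters = sorted(set("".join(terms)))
    # -1 everywhere a letter is not involved in a term
    mat = np.full((len(terms), len(letters)), -1, dtype=np.int64)
    for i, term in enumerate(terms):
        for p, letter in enumerate(term):
            # a repeated letter (e.g. "ii") keeps its last position here,
            # the library reports duplicates separately in its fourth result
            mat[i, letters.index(letter)] = p
    return "".join(letters), mat


letters, mat = letter_position_matrix("bac,cd,def->ebc")
print(letters)
print(mat)
```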

apply_einsum_sequence

onnx_extended.tools.einsum.einsum_impl.apply_einsum_sequence(seq: List[ndarray], *inputs: List[EinsumSubOp], verbose: bool = False, **kwargs: Dict[str, Any]) -> ndarray

Applies a sequence of operations to a list of inputs. The sequence of operations is produced by function decompose_einsum_equation().

Parameters:
  • seq – sequence of operations

  • inputs – inputs

  • verbose – verbosity

  • kwargs – additional parameters, see apply_sequence in GraphEinsumSubOp

Returns:

output

<<<

import numpy
from onnx_extended.tools.einsum import decompose_einsum_equation, apply_einsum_sequence

m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
m2 = numpy.arange(4).reshape((2, 2)) + 100
m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000

seq = decompose_einsum_equation("bac,cd,def->ebc")
res = apply_einsum_sequence(seq, m1, m2, m3)
print(res)

>>>

    [[[ 8866198  9864696]
      [12090270 13152928]]
    
     [[ 8883886  9884376]
      [12114390 13179168]]]
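A decomposition must return exactly what numpy.einsum() returns, so calling numpy.einsum() on the same inputs is a quick way to check the output above. This snippet only uses numpy; with onnx_extended installed, numpy.testing.assert_array_equal can compare both results directly.

```python
import numpy

# Same inputs as in the example above.
m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
m2 = numpy.arange(4).reshape((2, 2)) + 100
m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000

# numpy.einsum is the reference the decomposed sequence must match.
expected = numpy.einsum("bac,cd,def->ebc", m1, m2, m3)
print(expected)
```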

CachedEinsum

class onnx_extended.tools.einsum.einsum_fct.CachedEinsum(equation: str, runtime: str = 'batch_dot', opset: int | None = None, optimize: bool = False, dtype: Any = <class 'numpy.float64'>, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None, key: int | None = None)

Stores all the information needed to cache the preprocessing of an einsum equation.

Parameters:
  • equation – numpy equation

  • runtime – see einsum

  • opset – ONNX opset

  • optimize – finds the best letter permutation

  • dtype – dtype

  • decompose – whether to decompose the Einsum operator or keep it as is

  • key – key used to cache this class

  • strategy – optimization strategy

  • verbose – displays progress information

The class creates the following attributes:

  • equation_: the best equivalent equation

  • graph_: the corresponding graph returned by function decompose_einsum_equation

  • onnx_: if a conversion to onnx is used, stores the onnx graph

  • runtime_: the function called by __call__ to run the selected runtime

build()

Preprocesses the equation and builds whatever is necessary to compute the result of the einsum equation.

static build_einsum(equation: str, runtime: str, opset: int, optimize: bool, dtype: Any, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None, key: int | None = None) -> CachedEinsum

Creates an instance of CachedEinsum.

build_onnx_einsum(input_names: List[str]) -> ModelProto

Builds an ONNX graph with a single einsum operator.

build_runtime()

Builds the runtime associated with the equation self.equation_.

default_inputs(N: int | None = None) -> List[ndarray]

Returns default inputs (reshaped numpy.arange + 0.7i).

Parameters:

N – size of every dimension (all dimensions have the same size)

If N is None, its value depends on the number of letters, to avoid spending too much time on optimization.
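A possible reading of "reshaped numpy.arange + 0.7i" is sketched below: input i is numpy.arange(N ** d) reshaped to d dimensions of size N, shifted by 0.7 times the input index i. Interpreting "0.7i" this way is an assumption, and `default_inputs_sketch` is a hypothetical function that takes N explicitly instead of deriving it from the number of letters.

```python
import numpy as np


def default_inputs_sketch(equation, N, dtype=np.float64):
    """Hypothetical re-implementation of the description of default_inputs.
    Assumption: '0.7i' means 0.7 times the index i of the input."""
    terms = equation.split("->")[0].split(",")
    inputs = []
    for i, term in enumerate(terms):
        d = len(term)  # one dimension of size N per letter of the term
        x = np.arange(N ** d).reshape((N,) * d) + 0.7 * i
        inputs.append(x.astype(dtype))
    return inputs


inputs = default_inputs_sketch("abc,cd->abd", N=3)
print([x.shape for x in inputs])
print(inputs[1][0, :])
```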

compute_transposition_features

onnx_extended.tools.einsum.einsum_ml.compute_transposition_features(shape: Tuple[int, ...], perm: Tuple[int, ...]) -> Dict[str, float]

Given a shape and a permutation, computes many features used to predict the cost of the transposition.

Parameters:
  • shape – shape

  • perm – permutation

Returns:

dictionary of features

<<<

import pprint
from onnx_extended.tools.einsum.einsum_ml import compute_transposition_features

pprint.pprint(compute_transposition_features((3, 5, 7), (2, 1, 0)))

>>>

    {'CST_': -1,
     'begin': -1,
     'dbegin': 0,
     'dend': 0,
     'dim': 3,
     'discont': 2,
     'edit': 2,
     'end': -1,
     'end16': -16,
     'end32': -32,
     'ibegin16': -0.0,
     'ibegin2': -0.0,
     'ibegin32': -0.0,
     'ibegin4': -0.0,
     'ibegin64': -0.0,
     'ibegin8': -0.0,
     'iend16': -0.0,
     'iend2': -0.0,
     'iend32': -0.0,
     'iend4': -0.0,
     'iend64': -0.0,
     'iend8': -0.0,
     'middle': 105,
     'rbegin': -0.009523809523809525,
     'rdiscont': -0.01904761904761905,
     'redit': 0.6666666666666666,
     'rend': -0.009523809523809525,
     'rend16': -0.1523809523809524,
     'rend32': -0.3047619047619048,
     'rev': 1,
     'rmiddle': 1.0,
     'rot': 0,
     'size': 105}
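These features are meant to predict a cost, and measuring that cost directly gives a point of comparison. The sketch below assumes "cost" means the time needed to materialize a contiguous transposed copy; the exact target the library's model was trained on is not described here, and `transposition_cost` is a hypothetical helper.

```python
import timeit

import numpy as np


def transposition_cost(shape, perm, number=100):
    """Hypothetical helper: average time in seconds to build a contiguous
    copy of a random array of ``shape`` transposed with ``perm``."""
    x = np.random.randn(*shape)

    def run():
        # np.transpose only returns a view; ascontiguousarray performs the
        # copy whose memory access pattern depends on the permutation
        return np.ascontiguousarray(np.transpose(x, perm))

    return timeit.timeit(run, number=number) / number


shape = (30, 50, 70)
for perm in [(0, 1, 2), (2, 1, 0)]:
    print(perm, transposition_cost(shape, perm))
```

With the identity permutation the transposed view is already contiguous, so no copy happens and the measurement acts as a baseline.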

decompose_einsum_equation

onnx_extended.tools.einsum.einsum_impl.decompose_einsum_equation(equation: str, *shapes: List[Tuple[int, ...]], strategy: str = 'simple', clean: bool = False, verbose: bool = False) -> GraphEinsumSubOp

Decomposes an equation used in numpy.einsum(), given the input shapes, into a sequence of simple operations that compute the result.

Parameters:
  • equation – a string

  • shapes – sequence of input shapes

  • strategy – there are different ways to decompose the equation; this parameter selects one of them (see below)

  • clean – removes unnecessary nodes from the graph

  • verbose – verbosity

Returns:

instance of GraphEinsumSubOp

About strategy:

  • ‘simple’: aligns all dimensions in alphabetical order; some generic matrix multiplications are still implemented with numpy.einsum(), but only with two matrices aligned on the same dimensions (see numpy_extended_dot)

  • ‘numpy’: same as ‘simple’, but the decomposition no longer uses numpy.einsum(); it only uses multiplications or matrix multiplications merged into a single operator called batch_dot (see numpy_extended_dot_matrix)

Available operations: expand_dims, transpose, matmul, reduce_sum, id, squeeze, diagonal. The function analyses an equation and produces a graph whose nodes are instances of class EinsumSubOp.
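The ‘numpy’ strategy relies on the fact that a contraction of two tensors aligned on the same dimensions reduces to one matrix multiplication after a transposition and a reshape. The helper below is a hypothetical sketch of that idea, not the library's numpy_extended_dot_matrix: its arguments mimic the axes, left and right attributes printed for the matmul nodes in the example below, and it leaves out batch axes kept on both sides, which batch_dot also supports.

```python
import math

import numpy as np


def aligned_dot(x, y, sum_axes, left, right):
    """Hypothetical helper: contracts two tensors aligned on the same number
    of dimensions with a single matrix multiplication. x holds the axes in
    ``left`` and ``sum_axes``, y those in ``sum_axes`` and ``right``; every
    other axis must have size 1."""
    assert not set(left) & set(right), "batch axes are not handled here"
    ndim = x.ndim
    sum_size = math.prod(x.shape[a] for a in sum_axes)
    # x -> matrix (kept left axes, summed axes)
    px = list(left) + list(sum_axes)
    px += [a for a in range(ndim) if a not in px]
    xm = x.transpose(px).reshape(math.prod(x.shape[a] for a in left), sum_size)
    # y -> matrix (summed axes, kept right axes)
    py = list(sum_axes) + list(right)
    py += [a for a in range(ndim) if a not in py]
    ym = y.transpose(py).reshape(sum_size, math.prod(y.shape[a] for a in right))
    res = xm @ ym
    # one axis per kept dimension, sorted by position, then size-1 axes back
    kept = list(left) + list(right)
    res = res.reshape([x.shape[a] for a in left] + [y.shape[a] for a in right])
    res = res.transpose(sorted(range(len(kept)), key=kept.__getitem__))
    out_shape = [1] * ndim
    for a in left:
        out_shape[a] = x.shape[a]
    for a in right:
        out_shape[a] = y.shape[a]
    return res.reshape(out_shape)


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 2, 3, 4, 1, 1))  # uses axes 1, 2, 3 (b, c, d)
y = rng.standard_normal((1, 1, 1, 4, 5, 1))  # uses axes 3, 4 (d, e)
res = aligned_dot(x, y, sum_axes=(3,), left=(1, 2), right=(4,))
print(res.shape)
```

The result equals `(x * y).sum(axis=3, keepdims=True)`, but the summation is delegated to a single matrix multiplication.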

<<<

from onnx_extended.tools.einsum import decompose_einsum_equation

seq = decompose_einsum_equation("bac,cd,def->ebc")
for op in seq:
    print(op)

>>>

    EinsumSubOp('id', 0, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5)))
    EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5))
    EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,))
    EinsumSubOp('id', 1, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5)))
    EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6)
    EinsumSubOp('id', 2, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2)))
    EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,))
    EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6)
    EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(0, 4, 1, 3, 2, 5))
    EinsumSubOp('squeeze', EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(0, 4, 1, 3, 2, 5)), axes=(0, 3, 5))

The same sequence is easier to read as a graph:

digraph{
orientation=portrait;
ranksep=0.25;
nodesep=0.05;
width=0.5;
height=0.1;
size=5;
node [shape=record];
0 [label="input 0\\nbac\\n[ 1  0  2 -1 -1 -1]"];
139692820899504 [label="id\\nNone"];
0 -> 139692820899504;
139692820905504 [label="expand_dims\\naxes=((3, 3), (3, 4), (3, 5))None"];
139692820899504 -> 139692820905504;
139692821358784 [label="transpose\\nperm=(1, 0, 2, 3, 4, 5)None"];
139692820905504 -> 139692821358784;
139687590973680 [label="reduce_sum - I0\\naxes=(0,)None" style=filled fillcolor=red];
139692821358784 -> 139687590973680;
1 [label="input 1\\ncd\\n[-1 -1  0  1 -1 -1]"];
139687590973536 [label="id\\nNone"];
1 -> 139687590973536;
139687590973632 [label="expand_dims\\naxes=((0, 0), (0, 1), (2, 4), (2, 5))None"];
139687590973536 -> 139687590973632;
139687590974160 [label="matmul - I1\\naxes=() left=(1, 2) ndim=6 right=(2, 3)\\n~aBCdef,abCDef-\\=aBCDef" style=filled fillcolor=red];
139687590973680 -> 139687590974160;
139687590973632 -> 139687590974160;
2 [label="input 2\\ndef\\n[-1 -1 -1  0  1  2]"];
139687590974112 [label="id\\nNone"];
2 -> 139687590974112;
139687590974016 [label="expand_dims\\naxes=((0, 0), (0, 1), (0, 2))None"];
139687590974112 -> 139687590974016;
139687590974256 [label="reduce_sum\\naxes=(5,)None"];
139687590974016 -> 139687590974256;
139687590974736 [label="matmul - I2\\naxes=(3,) left=(1, 2) ndim=6 right=(4,)\\n~aBCdef,abcdEf-\\=aBCEf" style=filled fillcolor=red];
139687590974160 -> 139687590974736;
139687590974256 -> 139687590974736;
139687590974544 [label="transpose\\nperm=(0, 4, 1, 3, 2, 5)None"];
139687590974736 -> 139687590974544;
139687590974688 [label="squeeze - I-1\\naxes=(0, 3, 5)None" style=filled fillcolor=red];
139687590974544 -> 139687590974688;
}
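The printed sequence can be replayed by hand with numpy to see what each node does to the shapes. The sketch follows the nodes above for "bac,cd,def->ebc", with all dimensions aligned on the letters a to f. Where the library runs its own dot for the matmul nodes, this sketch uses a broadcast product followed by a sum, and `decompose_by_hand` is a hypothetical name.

```python
import numpy as np


def decompose_by_hand(m1, m2, m3):
    """Replays the sequence printed above for "bac,cd,def->ebc"."""
    # input 0, "bac": expand_dims -> (b, a, c, 1, 1, 1)
    x0 = m1.reshape(m1.shape + (1, 1, 1))
    x0 = np.transpose(x0, (1, 0, 2, 3, 4, 5))  # (a, b, c, 1, 1, 1)
    x0 = x0.sum(axis=0, keepdims=True)  # a only appears in input 0
    # input 1, "cd": expand_dims -> (1, 1, c, d, 1, 1)
    x1 = m2.reshape((1, 1) + m2.shape + (1, 1))
    # matmul with axes=(): nothing is summed, a broadcast product
    y = x0 * x1  # (1, b, c, d, 1, 1)
    # input 2, "def": expand_dims -> (1, 1, 1, d, e, f), f is summed away
    x2 = m3.reshape((1, 1, 1) + m3.shape).sum(axis=5, keepdims=True)
    # matmul with axes=(3,): product, then sum over d
    y = (y * x2).sum(axis=3, keepdims=True)  # (1, b, c, 1, e, 1)
    y = np.transpose(y, (0, 4, 1, 3, 2, 5))  # (1, e, b, 1, c, 1)
    return np.squeeze(y, axis=(0, 3, 5))  # (e, b, c)


m1 = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
m2 = np.arange(4).reshape((2, 2)) + 100
m3 = np.arange(8).reshape((2, 2, 2)) + 1000
print(decompose_by_hand(m1, m2, m3))
```

It should print the same array as the apply_einsum_sequence example.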

einsum

onnx_extended.tools.einsum.einsum_fct.einsum(equation: str, *inputs: List[ndarray], optimize: bool = False, runtime: str = 'batch_dot', cache: bool = True, opset: int | None = None, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None) -> ndarray

Proposes a new implementation of numpy.einsum(). It does not allow expressions using ... (ellipsis) and expects a right member (an explicit output after ->).

Parameters:
  • equation – einsum equation

  • inputs – inputs

  • optimize – permutes all letters to find the best permutation

  • runtime – runtime used to compute the results once the computation graph is produced (see below)

  • cache – if True, the function stores the preprocessing done for a specific equation, so a second call with the same equation is much faster

  • opset – ONNX opset to use for some runtimes

  • decompose – by default, the function decomposes the equation into simpler operators, but it can also keep the original ONNX Einsum operator

  • strategy – optimisation strategy (see below)

  • verbose – display progress if optimize is True

Returns:

einsum result

The available runtimes are:

  • batch_dot: the runtime is apply_einsum_sequence,

  • python: one ONNX graph executed with a python runtime,

  • onnxruntime: one ONNX graph executed with onnxruntime.

The optimisation strategy can be:

  • None: the same runtime is used to find the best permutation of letters

  • ‘ml’: a machine-learned model is used to predict the best permutation of letters.

The function works in two steps:

  • the first step analyses the equation to produce a computation graph, which can also be converted into ONNX,

  • the second step executes that graph with the selected runtime.

Further details are available in the documentation of function optimize_decompose_einsum_equation(). The function works the same way as numpy.einsum():

<<<

import numpy
from onnx_extended.tools.einsum import einsum

equation = "abc,cd->abd"

m1 = numpy.random.randn(2, 2, 2)
m2 = numpy.random.randn(2, 2)

ref = numpy.einsum(equation, m1, m2)
print("numpy.einsum")
print(ref)

print("onnx_extended.tools.einsum")
res = einsum(equation, m1, m2)
print(res)

>>>

    numpy.einsum
    [[[ 0.411 -0.389]
      [ 0.111 -3.012]]
    
     [[-0.17   1.699]
      [ 0.306  1.594]]]
    onnx_extended.tools.einsum
    [[[ 0.411 -0.389]
      [ 0.111 -3.012]]
    
     [[-0.17   1.699]
      [ 0.306  1.594]]]

In some cases, the einsum implementation can be made faster by looping over possible permutations of the letters:

<<<

import timeit
import numpy
from onnx_extended.tools.einsum import einsum
from onnx_extended.tools.einsum.einsum_fct import enumerate_cached_einsum

equation = "cab,cd->ad"

m1 = numpy.random.randn(20, 20, 20)
m2 = numpy.random.randn(20, 20)

print(
    "numpy.einsum",
    timeit.timeit("numpy.einsum(equation, m1, m2)", number=200, globals=globals()),
)

einsum(equation, m1, m2)
print(
    "einsum", timeit.timeit("einsum(equation, m1, m2)", number=200, globals=globals())
)

einsum(equation, m1, m2, runtime="python")
print(
    "einsum-python",
    timeit.timeit(
        'einsum(equation, m1, m2, runtime="python")', number=200, globals=globals()
    ),
)

einsum(equation, m1, m2, runtime="onnxruntime")
print(
    "einsum-onnxruntime",
    timeit.timeit(
        'einsum(equation, m1, m2, runtime="onnxruntime")', number=200, globals=globals()
    ),
)

einsum(equation, m1, m2, runtime="onnxruntime", optimize=True, verbose=1)
print(
    "einsum-onnxruntime",
    timeit.timeit(
        'einsum(equation, m1, m2, runtime="onnxruntime", optimize=True)',
        number=200,
        globals=globals(),
    ),
)

print("list of cached einsum equations")
for k, v in enumerate_cached_einsum():
    print(k, v.equation, v.equation_)

>>>

    numpy.einsum 0.03852240000014717
    einsum 0.01917120000007344
    einsum-python 0.09263770000006843
    einsum-onnxruntime 0.014252599999963422
    einsum-onnxruntime 0.011129400000299938
    list of cached einsum equations
    ('cab,cd->ad', 'batch_dot', None, False, dtype('float64'), True, None) cab,cd->ad cab,cd->ad
    ('cab,cd->ad', 'python', None, False, dtype('float64'), True, None) cab,cd->ad cab,cd->ad
    ('cab,cd->ad', 'onnxruntime', None, False, dtype('float64'), True, None) cab,cd->ad cab,cd->ad
    ('cab,cd->ad', 'onnxruntime', None, True, dtype('float64'), True, None) cab,cd->ad cda,cb->db
    
      0%|          | 0/25 [00:00<?, ?it/s]
    0.001 rtbest='cab,cd->ad':   0%|          | 0/25 [00:00<?, ?it/s]
    0.00085 rtbest='cab,cd->ad':   0%|          | 0/25 [00:00<?, ?it/s]
    0.00038 rtbest='dab,dc->ac':   0%|          | 0/25 [00:00<?, ?it/s]
    0.00037 rtbest='bad,bc->ac':   0%|          | 0/25 [00:00<?, ?it/s]
    0.00037 rtbest='cad,cb->ab':   0%|          | 0/25 [00:00<?, ?it/s]
    0.00037 rtbest='cad,cb->ab':  28%|██▊       | 7/25 [00:00<00:00, 63.85it/s]
    0.00037 rtbest='cda,cb->db':  28%|██▊       | 7/25 [00:00<00:00, 63.85it/s]
    0.00037 rtbest='cda,cb->db':  56%|█████▌    | 14/25 [00:00<00:00, 65.47it/s]
    0.00037 rtbest='cda,cb->db':  88%|████████▊ | 22/25 [00:00<00:00, 69.04it/s]
    0.00037 rtbest='cda,cb->db': 100%|██████████| 25/25 [00:00<00:00, 68.63it/s]
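With optimize=True, the cached entry maps cab,cd->ad to cda,cb->db, an equivalent equation obtained by renaming letters. Renaming leaves the result unchanged; since the ‘simple’ decomposition aligns dimensions in alphabetical order, it presumably changes the transpositions the decomposition needs, which is what the search exploits. The hypothetical helper below only enumerates the 4! = 24 renamings of this equation; how the library builds its 25 candidates is not shown here.

```python
import itertools

import numpy as np


def renamed_equations(equation):
    """Hypothetical helper: every equation obtained by renaming the letters
    of ``equation`` with a permutation of those same letters."""
    letters = "".join(sorted(set(c for c in equation if c.isalpha())))
    for perm in itertools.permutations(letters):
        yield equation.translate(str.maketrans(letters, "".join(perm)))


candidates = list(renamed_equations("cab,cd->ad"))
print(len(candidates))
print("cda,cb->db" in candidates)

# every renaming computes the same result
m1 = np.random.randn(5, 6, 7)
m2 = np.random.randn(5, 8)
ref = np.einsum("cab,cd->ad", m1, m2)
same = all(np.allclose(np.einsum(eq, m1, m2), ref) for eq in candidates)
print(same)
```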

The last example profiles the time spent in every function:

<<<

import logging
import os
import cProfile
from io import StringIO
from pstats import Stats
import numpy
from onnx_extended.tools.einsum import einsum
from onnx_extended.tools.einsum.einsum_fct import enumerate_cached_einsum
from onnx_extended import __file__ as path


def profile(fct, sort="cumulative", **kwargs):
    pr = cProfile.Profile(**kwargs)
    pr.enable()
    fct_res = fct()
    pr.disable()
    s = StringIO()
    ps = Stats(pr, stream=s).sort_stats(sort)
    ps.print_stats()
    res = s.getvalue()
    return ps, res


root = os.path.dirname(path)
logging.getLogger("onnx-extended").setLevel(logging.ERROR)

equation = "cab,cd->ad"

m1 = numpy.random.randn(200, 20, 20)
m2 = numpy.random.randn(200, 20)


def clean(txt):
    txt = txt.replace(root, "onnx_extended")
    return "\n".join(txt.split("\n")[:30])


def fct1():
    for i in range(100):
        einsum(equation, m1, m2, cache=False)


print("Profile cache with default runtime.")
res = profile(fct1)
print(root)
print(clean(res[1]))


def fct2():
    for i in range(100):
        einsum(equation, m1, m2, cache=False, runtime="python")


print("Profile cache with runtime='python'.")
res = profile(fct2)
print(root)
print(clean(res[1]))


def fct3():
    for i in range(100):
        einsum(equation, m1, m2, cache=True)


einsum(equation, m1, m2, cache=True)
print("Profile execution with default runtime.")
res = profile(fct3)
print(root)
print(clean(res[1]))


def fct4():
    for i in range(100):
        einsum(equation, m1, m2, cache=True, runtime="python")


einsum(equation, m1, m2, cache=True, runtime="python")
print("Profile execution with runtime='python'.")
res = profile(fct4)
print(root)
print(clean(res[1]))


def fct5():
    for i in range(100):
        einsum(equation, m1, m2, cache=True, runtime="onnxruntime")


einsum(equation, m1, m2, cache=True, runtime="onnxruntime")
print("Profile execution with runtime='onnxruntime'.")
res = profile(fct5)
print(root)
print(clean(res[1]))

>>>

    Profile cache with default runtime.
    /home/xadupre/github/onnx-extended/onnx_extended
             128702 function calls (128502 primitive calls) in 0.207 seconds
    
       Ordered by: cumulative time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    0.207    0.207 <stdin>:45(fct1)
          100    0.001    0.000    0.207    0.002 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
          100    0.000    0.000    0.148    0.001 onnx_extended/tools/einsum/einsum_fct.py:599(optimize_decompose_einsum_equation)
          100    0.000    0.000    0.148    0.001 onnx_extended/tools/einsum/einsum_fct.py:562(_einsum)
          100    0.000    0.000    0.147    0.001 onnx_extended/tools/einsum/einsum_fct.py:532(build_einsum)
          100    0.000    0.000    0.147    0.001 onnx_extended/tools/einsum/einsum_fct.py:305(build)
          100    0.000    0.000    0.147    0.001 onnx_extended/tools/einsum/einsum_fct.py:475(build_runtime)
          100    0.001    0.000    0.146    0.001 onnx_extended/tools/einsum/einsum_impl.py:79(decompose_einsum_equation)
          100    0.012    0.000    0.103    0.001 onnx_extended/tools/einsum/einsum_impl.py:415(_decompose_einsum_equation_simple)
          100    0.000    0.000    0.058    0.001 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
          100    0.000    0.000    0.058    0.001 onnx_extended/tools/einsum/einsum_fct.py:500(<lambda>)
          100    0.000    0.000    0.058    0.001 onnx_extended/tools/einsum/einsum_impl.py:163(apply_einsum_sequence)
          100    0.002    0.000    0.058    0.001 onnx_extended/tools/einsum/einsum_impl_classes.py:1390(apply_sequence)
         1200    0.004    0.000    0.055    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:717(apply)
         1200    0.008    0.000    0.041    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:412(compute_output_row)
          100    0.000    0.000    0.027    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:1415(clean_unused_nodes)
          200    0.015    0.000    0.027    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:1422(iteration)
         1900    0.016    0.000    0.016    0.000 {method 'reduce' of 'numpy.ufunc' objects}
          500    0.002    0.000    0.016    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:71(_wrapreduction)
          100    0.003    0.000    0.015    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:596(_apply_batch_dot)
         4800    0.005    0.000    0.015    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:19(single_axes)
          100    0.001    0.000    0.013    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:676(_apply_reduce_sum)
          100    0.000    0.000    0.012    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2177(sum)
          600    0.005    0.000    0.012    0.000 onnx_extended/tools/einsum/einsum_impl.py:296(_apply_einsum_matmul)
        22100    0.011    0.000    0.011    0.000 {built-in method builtins.isinstance}
    Profile cache with runtime='python'.
    /home/xadupre/github/onnx-extended/onnx_extended
             476555 function calls (471719 primitive calls) in 0.826 seconds
    
       Ordered by: cumulative time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    0.826    0.826 <stdin>:56(fct2)
          100    0.001    0.000    0.826    0.008 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
          100    0.000    0.000    0.588    0.006 onnx_extended/tools/einsum/einsum_fct.py:599(optimize_decompose_einsum_equation)
          100    0.000    0.000    0.588    0.006 onnx_extended/tools/einsum/einsum_fct.py:562(_einsum)
          100    0.000    0.000    0.588    0.006 onnx_extended/tools/einsum/einsum_fct.py:532(build_einsum)
          100    0.000    0.000    0.587    0.006 onnx_extended/tools/einsum/einsum_fct.py:305(build)
          100    0.002    0.000    0.587    0.006 onnx_extended/tools/einsum/einsum_fct.py:475(build_runtime)
          100    0.001    0.000    0.251    0.003 onnx_extended/reference/c_reference_evaluator.py:227(__init__)
          100    0.000    0.000    0.237    0.002 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
          100    0.001    0.000    0.237    0.002 onnx_extended/tools/einsum/einsum_fct.py:512(<lambda>)
          100    0.031    0.000    0.236    0.002 onnx_extended/reference/c_reference_evaluator.py:258(run)
          100    0.004    0.000    0.228    0.002 /home/xadupre/github/onnx/onnx/reference/reference_evaluator.py:202(__init__)
          100    0.010    0.000    0.213    0.002 /home/xadupre/github/onnx/onnx/reference/reference_evaluator.py:398(_init)
          100    0.013    0.000    0.180    0.002 onnx_extended/tools/einsum/einsum_impl_classes.py:1652(to_onnx)
          100    0.001    0.000    0.151    0.002 onnx_extended/tools/einsum/einsum_impl.py:79(decompose_einsum_equation)
         2300    0.021    0.000    0.144    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:426(run)
         5000    0.014    0.000    0.112    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:1144(to_onnx)
         2800    0.014    0.000    0.108    0.000 /home/xadupre/github/onnx/onnx/reference/reference_evaluator.py:454(_load_impl)
          100    0.012    0.000    0.108    0.001 onnx_extended/tools/einsum/einsum_impl.py:415(_decompose_einsum_equation_simple)
        212/9    0.001    0.000    0.077    0.009 <frozen importlib._bootstrap>:1022(_find_and_load)
        212/9    0.001    0.000    0.076    0.008 <frozen importlib._bootstrap>:987(_find_and_load_unlocked)
       212/12    0.001    0.000    0.075    0.006 <frozen importlib._bootstrap>:664(_load_unlocked)
       208/12    0.001    0.000    0.074    0.006 <frozen importlib._bootstrap_external>:877(exec_module)
       220/12    0.000    0.000    0.074    0.006 <frozen importlib._bootstrap>:233(_call_with_frames_removed)
       211/12    0.001    0.000    0.073    0.006 {built-in method builtins.exec}
    Profile execution with default runtime.
    /home/xadupre/github/onnx-extended/onnx_extended
             32202 function calls in 0.057 seconds
    
       Ordered by: cumulative time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    0.057    0.057 <stdin>:67(fct3)
          100    0.001    0.000    0.057    0.001 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
          100    0.000    0.000    0.056    0.001 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
          100    0.000    0.000    0.055    0.001 onnx_extended/tools/einsum/einsum_fct.py:500(<lambda>)
          100    0.000    0.000    0.055    0.001 onnx_extended/tools/einsum/einsum_impl.py:163(apply_einsum_sequence)
          100    0.002    0.000    0.055    0.001 onnx_extended/tools/einsum/einsum_impl_classes.py:1390(apply_sequence)
         1200    0.003    0.000    0.052    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:717(apply)
          500    0.002    0.000    0.015    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:71(_wrapreduction)
          100    0.003    0.000    0.014    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:596(_apply_batch_dot)
          100    0.001    0.000    0.013    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:676(_apply_reduce_sum)
          500    0.012    0.000    0.012    0.000 {method 'reduce' of 'numpy.ufunc' objects}
          100    0.000    0.000    0.011    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2177(sum)
          200    0.001    0.000    0.008    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:503(_apply_expand_dims)
          400    0.002    0.000    0.008    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:516(_apply_transpose)
          300    0.002    0.000    0.006    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/lib/shape_base.py:512(expand_dims)
         1300    0.003    0.000    0.006    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:463(_get_data)
          400    0.001    0.000    0.005    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2979(prod)
         4500    0.002    0.000    0.002    0.000 {built-in method builtins.len}
         1300    0.002    0.000    0.002    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:455(_check_shape_)
         1200    0.001    0.000    0.002    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:446(_check_inputs_)
          300    0.001    0.000    0.002    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/numeric.py:1330(normalize_axis_tuple)
          400    0.001    0.000    0.002    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:588(transpose)
          300    0.001    0.000    0.002    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:478(_apply_id)
          100    0.001    0.000    0.002    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:703(_apply_squeeze)
          100    0.002    0.000    0.002    0.000 onnx_extended/tools/einsum/blas_lapack.py:76(gemm_dot)
    Profile execution with runtime='python'.
    /home/xadupre/github/onnx-extended/onnx_extended
             145502 function calls in 0.241 seconds
    
       Ordered by: cumulative time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    0.241    0.241 <stdin>:79(fct4)
          100    0.001    0.000    0.241    0.002 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
          100    0.000    0.000    0.239    0.002 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
          100    0.001    0.000    0.239    0.002 onnx_extended/tools/einsum/einsum_fct.py:512(<lambda>)
          100    0.032    0.000    0.238    0.002 onnx_extended/reference/c_reference_evaluator.py:258(run)
         2300    0.021    0.000    0.146    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:426(run)
         3400    0.015    0.000    0.070    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:394(_check_and_fix_outputs)
          400    0.002    0.000    0.029    0.000 /home/xadupre/github/onnx/onnx/reference/ops/_op.py:48(run)
         6800    0.012    0.000    0.027    0.000 {built-in method builtins.any}
         6800    0.007    0.000    0.026    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:411(<genexpr>)
        20900    0.014    0.000    0.020    0.000 {built-in method builtins.isinstance}
         3400    0.007    0.000    0.019    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/numeric.py:1855(isscalar)
          400    0.002    0.000    0.019    0.000 /home/xadupre/github/onnx/onnx/reference/ops/_op.py:22(run)
         5600    0.008    0.000    0.019    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:227(_log)
          100    0.001    0.000    0.014    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_reduce_sum.py:22(_run)
          100    0.000    0.000    0.012    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2177(sum)
          100    0.001    0.000    0.012    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:71(_wrapreduction)
          100    0.011    0.000    0.011    0.000 {method 'reduce' of 'numpy.ufunc' objects}
         5600    0.008    0.000    0.011    0.000 /home/xadupre/github/onnx/onnx/reference/reference_evaluator.py:409(<lambda>)
          300    0.001    0.000    0.010    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_reshape.py:34(_run)
          300    0.004    0.000    0.009    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_reshape.py:11(reshape_reference_implementation)
          100    0.000    0.000    0.008    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_max.py:15(run)
         6800    0.006    0.000    0.008    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:414(<genexpr>)
          100    0.000    0.000    0.008    0.000 /home/xadupre/github/onnx/onnx/reference/ops/_op.py:117(run)
         6800    0.006    0.000    0.008    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:404(<genexpr>)
    Profile execution with runtime='onnxruntime'.
    /home/xadupre/github/onnx-extended/onnx_extended
             1602 function calls in 0.031 seconds
    
       Ordered by: cumulative time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    0.031    0.031 <stdin>:91(fct5)
          100    0.001    0.000    0.030    0.000 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
          100    0.000    0.000    0.029    0.000 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
          100    0.000    0.000    0.029    0.000 onnx_extended/tools/einsum/einsum_fct.py:512(<lambda>)
          100    0.028    0.000    0.028    0.000 /home/xadupre/github/onnxruntime/build/linux_cuda/Release/onnxruntime/capi/onnxruntime_inference_collection.py:202(run)
          100    0.000    0.000    0.000    0.000 onnx_extended/tools/einsum/einsum_fct.py:599(optimize_decompose_einsum_equation)
          100    0.000    0.000    0.000    0.000 onnx_extended/tools/einsum/einsum_fct.py:562(_einsum)
          100    0.000    0.000    0.000    0.000 /home/xadupre/github/onnxruntime/build/linux_cuda/Release/onnxruntime/capi/onnxruntime_inference_collection.py:192(_validate_input)
          300    0.000    0.000    0.000    0.000 onnx_extended/tools/einsum/einsum_fct.py:931(<genexpr>)
          100    0.000    0.000    0.000    0.000 /home/xadupre/github/onnxruntime/build/linux_cuda/Release/onnxruntime/capi/onnxruntime_inference_collection.py:218(<listcomp>)
          100    0.000    0.000    0.000    0.000 onnx_extended/tools/einsum/einsum_fct.py:513(<dictcomp>)
          100    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
          100    0.000    0.000    0.000    0.000 {built-in method builtins.len}
          100    0.000    0.000    0.000    0.000 {method 'keys' of 'dict' objects}
          100    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

einsum_benchmark#

onnx_extended.tools.einsum.einsum_bench.einsum_benchmark(equation: str = 'abc,cd->abd', shape: int = 30, perm: bool = False, runtime: str = 'python', use_tqdm: bool = False, number: int = 5, repeat: int = 5, opset=18) Iterable[Dict[str, str | float]][source]#

Investigates whether decomposing an einsum equation into simpler operators is faster than computing it directly.

Parameters:
  • equation – einsum equation to test

  • shape – an integer (all dimensions get the same size) or a list of shapes in a string, separated by ;

  • perm – if True, benchmarks every permutation of the letters of the equation, otherwise only the equation itself

  • runtime – a string among ‘numpy’, ‘python’, ‘onnxruntime’

  • use_tqdm – show progress

  • number – number of calls per measurement, as in timeit.repeat()

  • repeat – number of measurements, as in timeit.repeat()

  • opset – target opset

Returns:

list of dictionaries as an iterator
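The idea behind this benchmark can be sketched with numpy alone. The function below is hypothetical (it is not einsum_benchmark and its dictionary keys are made up): it times numpy.einsum() on the default equation 'abc,cd->abd' against a hand-decomposed version made of a reshape and a single 2D matrix product, with number and repeat used as in timeit.repeat().

```python
import timeit

import numpy as np


def benchmark_sketch(shape=30, number=5, repeat=5):
    """Hypothetical comparison of numpy.einsum with a decomposition of
    'abc,cd->abd'; yields one dictionary per implementation."""
    rng = np.random.default_rng(0)
    m1 = rng.random((shape, shape, shape))
    m2 = rng.random((shape, shape))

    def direct():
        return np.einsum("abc,cd->abd", m1, m2)

    def decomposed():
        # merge axes a and b, run one 2D matrix product, then split them again
        a, b, c = m1.shape
        return (m1.reshape((a * b, c)) @ m2).reshape((a, b, m2.shape[1]))

    # both implementations must agree before timing them
    if not np.allclose(direct(), decomposed()):
        raise AssertionError("decomposition does not match numpy.einsum")
    for name, fct in [("einsum", direct), ("decomposed", decomposed)]:
        # each measurement runs fct `number` times, `repeat` measurements in total
        times = timeit.repeat(fct, number=number, repeat=repeat)
        yield {
            "implementation": name,
            "shape": shape,
            "min": min(times) / number,
            "max": max(times) / number,
        }


for row in benchmark_sketch(shape=20):
    print(row)
```

Which variant wins depends on shapes and on the machine, which is why the real function measures rather than assumes.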

numpy_extended_dot#

onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot(m1: ndarray, m2: ndarray, axes: Tuple[int, ...], left: Tuple[int, ...], right: Tuple[int, ...], verbose: bool = False) ndarray[source]#

Extended version of matrix multiplication (numpy.dot()) for two arrays m1 and m2 with the same number of dimensions. It loops over the left axes of m1 and the right axes of m2 and sums over axes. Other axes must be empty (of size 1). This multiplication combines matrix multiplication (dot) and term-by-term broadcasted multiplication.

Parameters:
  • m1 – first matrix

  • m2 – second matrix

  • axes – summation axes

  • left – left axes

  • right – right axes

  • verbose – display intermediate information

Returns:

output

The dot product is equivalent to:

<<<

import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(4).reshape((2, 2))
m2 = m1 + 10
print("dot product")
print(m1 @ m2)

dm1 = m1.reshape((2, 2, 1))
dm2 = m2.reshape((1, 2, 2))
dot = numpy_extended_dot(dm1, dm2, axes=[1], left=[0], right=[2], verbose=True)
print("extended dot product")
print(dot)

>>>

    dot product
    [[12 13]
     [56 61]]
      [numpy_extended_dot] Abc,abC->AC: (2, 2, 1) @ (1, 2, 2)
      [numpy_extended_dot] (2, 2) reshaped into [2, 1, 2] 
    extended dot product
    [[[12 13]]
    
     [[56 61]]]

Empty axes must be squeezed to get results identical to numpy.dot(). The next example computes a dot product with the second matrix transposed.

<<<

import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(4).reshape((2, 2))
m2 = m1 + 10
print("dot product")
print(m1 @ m2.T)

dm1 = m1.reshape((2, 1, 2))
dm2 = m2.reshape((1, 2, 2))
dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1], verbose=True)
print("extended dot product")
print(dot)

>>>

    dot product
    [[11 13]
     [53 63]]
      [numpy_extended_dot] Abc,aBc->AB: (2, 1, 2) @ (1, 2, 2)
      [numpy_extended_dot] (2, 2) reshaped into [2, 2, 1] 
    extended dot product
    [[[11]
      [13]]
    
     [[53]
      [63]]]

An example when right axes include the summation axis.

<<<

import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(4).reshape((2, 2))
m2 = m1 + 10
dm1 = m1.reshape((2, 2, 1))
dm2 = m2.reshape((1, 2, 2))
dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1, 2], verbose=True)
print(dot)

>>>

      [numpy_extended_dot] Abc,aBc->ABc: (2, 2, 1) @ (1, 2, 2)
      [numpy_extended_dot] (2, 2, 2) reshaped into [2, 2, 2] 
    [[[10 11]
      [12 13]]
    
     [[50 55]
      [60 65]]]

Example in higher dimension:

<<<

import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(8).reshape((2, 2, 2))
m2 = m1 + 10

dot = numpy_extended_dot(m1, m2, [1], [0], [2], verbose=True)
print(dot)

>>>

      [numpy_extended_dot] Abc,abC->AC: (2, 2, 2) @ (2, 2, 2)
      [numpy_extended_dot] (2, 2) reshaped into [2, 1, 2] 
    [[[164 176]]
    
     [[580 624]]]

The current implementation still uses numpy.einsum() but this should be replaced.
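Since the current implementation relies on numpy.einsum(), its semantics can be sketched by building the equation printed with verbose=True (for instance Abc,abC->AC) and reshaping the result back to the full number of dimensions. The helper extended_dot_sketch below is hypothetical, only covers the case where axes, left and right are disjoint, and is not the library's code.

```python
import string

import numpy as np


def extended_dot_sketch(m1, m2, axes, left, right):
    """Hypothetical einsum-based sketch of numpy_extended_dot.
    Assumes axes, left and right are disjoint."""
    if m1.ndim != m2.ndim:
        raise ValueError("m1 and m2 must have the same number of dimensions")
    letters = string.ascii_lowercase[: m1.ndim]
    # uppercase letters mark the axes kept in the output
    in1 = "".join(c.upper() if i in left else c for i, c in enumerate(letters))
    in2 = "".join(c.upper() if i in right else c for i, c in enumerate(letters))
    kept = [i for i in range(m1.ndim) if i in left or i in right]
    out = "".join(letters[i].upper() for i in kept)
    res = np.einsum(f"{in1},{in2}->{out}", m1, m2)
    # restore the full number of dimensions, size 1 for the other axes
    full = [
        m1.shape[i] if i in left else m2.shape[i] if i in right else 1
        for i in range(m1.ndim)
    ]
    return res.reshape(full)


m1 = np.arange(4).reshape((2, 2))
m2 = m1 + 10
dot = extended_dot_sketch(
    m1.reshape((2, 2, 1)), m2.reshape((1, 2, 2)), axes=[1], left=[0], right=[2]
)
print(dot.shape)  # (2, 1, 2)
print(dot.squeeze(axis=1))  # same values as m1 @ m2
```

Letters that appear in an input but not in the output are summed by numpy.einsum(), which is why non-empty "other" axes (as in the higher dimension example) end up summed as well.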

numpy_extended_dot_matrix#

onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_matrix(m1: ndarray, m2: ndarray, axes: Tuple[int, ...], left: Tuple[int, ...], right: Tuple[int, ...], verbose: bool = False) ndarray[source]#

Implementation of numpy_extended_dot built from dot products, multiplications, transpositions and reductions rather than from custom python loops like numpy_extended_dot_python.

<<<

import numpy
from onnx_extended.tools.einsum import numpy_extended_dot_matrix
from onnx_extended.tools.einsum.einsum_impl_ext import _numpy_extended_dot_equation

a = numpy.arange(6).reshape((3, 2, 1))
b = numpy.arange(12).reshape((3, 1, 4))

print(numpy_extended_dot_matrix(a, b, axes=(0,), left=(1,), right=(2,)))

# Equivalent einsum equation
print(
    "equation",
    _numpy_extended_dot_equation(
        len(a.shape), len(b.shape), axes=(0,), left=(1,), right=(2,)
    ),
)

# Same einsum computation written in a different way.
print(numpy.einsum("kix,kxj->xij", a, b))

>>>

    [[[40 46 52 58]
      [52 61 70 79]]]
    equation aBc,abC->BC
    [[[40 46 52 58]
      [52 61 70 79]]]
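For the example above (axes=(0,), left=(1,), right=(2,)), the computation reduces to dropping the empty axes, transposing a and calling a single 2D matrix product. This is a hand-written sketch for this specific case, not the general algorithm used by numpy_extended_dot_matrix.

```python
import numpy as np

a = np.arange(6).reshape((3, 2, 1))
b = np.arange(12).reshape((3, 1, 4))

# drop the empty axes: a -> (k=3, i=2), b -> (k=3, j=4)
a2 = a[:, :, 0]
b2 = b[:, 0, :]
# sum over k with one matrix product: (i, k) @ (k, j) -> (i, j)
res = a2.T @ b2
# restore the full number of dimensions, size 1 on the summed axis
res = res.reshape((1,) + res.shape)
print(res)
```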

numpy_extended_dot_python#

onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_python(m1: ndarray, m2: ndarray, axes: Tuple[int, ...], left: Tuple[int, ...], right: Tuple[int, ...], verbose: bool = False) ndarray[source]#

Implementation of numpy_extended_dot in pure python. This implementation is not efficient but shows how to implement this operation without numpy.einsum().

<<<

import numpy
from onnx_extended.tools.einsum import numpy_extended_dot_python
from onnx_extended.tools.einsum.einsum_impl_ext import _numpy_extended_dot_equation

a = numpy.arange(6).reshape((3, 2, 1))
b = numpy.arange(12).reshape((3, 1, 4))

print(numpy_extended_dot_python(a, b, axes=(0,), left=(1,), right=(2,)))

# Equivalent einsum equation
print(
    "equation",
    _numpy_extended_dot_equation(
        len(a.shape), len(b.shape), axes=(0,), left=(1,), right=(2,)
    ),
)

# Same einsum computation written in a different way.
print(numpy.einsum("kix,kxj->xij", a, b))

>>>

    [[[40 46 52 58]
      [52 61 70 79]]]
    equation aBc,abC->BC
    [[[40 46 52 58]
      [52 61 70 79]]]
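A loop-based version can be sketched as follows. extended_dot_loops is hypothetical: it assumes axes, left and right are disjoint and that every other axis of m1 and m2 has size 1, and it makes no attempt at efficiency since it visits every output index and every combination of summed indices.

```python
import itertools

import numpy as np


def extended_dot_loops(m1, m2, axes, left, right):
    """Hypothetical loop-based sketch; assumes axes, left and right are
    disjoint and every other axis of m1 and m2 has size 1."""
    ndim = m1.ndim
    for i in range(ndim):
        if i not in left and i not in axes and m1.shape[i] != 1:
            raise ValueError(f"axis {i} of m1 must have size 1")
        if i not in right and i not in axes and m2.shape[i] != 1:
            raise ValueError(f"axis {i} of m2 must have size 1")
    # kept axes come from m1 (left) or m2 (right), the others collapse to 1
    out_shape = tuple(
        m1.shape[i] if i in left else m2.shape[i] if i in right else 1
        for i in range(ndim)
    )
    sum_ranges = [range(m1.shape[i]) for i in axes]
    res = np.zeros(out_shape, dtype=np.result_type(m1, m2))
    for out_idx in itertools.product(*(range(d) for d in out_shape)):
        total = 0
        for sum_idx in itertools.product(*sum_ranges):
            summed = dict(zip(axes, sum_idx))
            # summed axes use the loop index, size-1 axes use 0
            idx1 = tuple(summed.get(i, out_idx[i] if i in left else 0) for i in range(ndim))
            idx2 = tuple(summed.get(i, out_idx[i] if i in right else 0) for i in range(ndim))
            total += m1[idx1] * m2[idx2]
        res[out_idx] = total
    return res


a = np.arange(6).reshape((3, 2, 1))
b = np.arange(12).reshape((3, 1, 4))
print(extended_dot_loops(a, b, axes=(0,), left=(1,), right=(2,)))
```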

EinsumSubOp#

class onnx_extended.tools.einsum.einsum_impl_classes.EinsumSubOp(full_dim: int, name: str, *inputs: List[EinsumSubOp], **kwargs: Dict[str, Any])[source]#

Defines a sub operation used in Einsum decomposition.

Parameters:
  • full_dim – dimension of the result

  • name – name (reshape, transpose, reduce_sum, matmul, id, squeeze, diagonal, mul, batch_dot)

  • inputs – inputs

  • kwargs – arguments

Operators suffixed by _mm (transpose_mm, reduce_sum_mm) are equivalent to the same operator without the suffix but take two inputs and only change the first one.

Attribute _info summarizes the known information about dimensions. Many dimensions are empty because they were inserted. Value 1 means the dimension is empty (inserted), 2 means it is a plain dimension.

add_info(**kwargs: Dict[str, Any])[source]#

Adds information to the node.

Parameters:

kwargs – dictionary

apply(data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]) Any[source]#

Applies one operator on the data.

Parameters:
  • data – dictionary storing the results

  • verbose – prints out intermediate results

  • kwargs – additional parameters, see methods _apply*

Returns:

output

Known additional parameters:

compute_output_row(row: ndarray, row2: ndarray | None = None, ab: bool = False, verbose: bool = False)[source]#

Updates row based on the operator.

dot_label() str | None[source]#

Returns a label with information useful to understand the operator.

get_dot_kind() str[source]#

Every matrix multiplication can be either:

  • a simple multiplication (M) (undetected)

  • a 2D matrix multiplication (11)

  • a broadcasted matrix multiplication (N1 or 1N)

  • a batch matrix multiplication (NN)

This method returns which kind it is.
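One way to picture these categories is to look at the leading (batch) dimensions of the two operands of a matrix multiplication. The classifier below is hypothetical and assumes operands of shape (..., n, k) and (..., k, m); it is not how get_dot_kind is implemented and it does not detect the simple multiplication case (M).

```python
import math

import numpy as np


def dot_kind_sketch(shape1, shape2):
    """Hypothetical classification of a matmul from its batch dimensions."""
    batch1 = math.prod(shape1[:-2])  # 1 when there is no batch dimension
    batch2 = math.prod(shape2[:-2])
    if batch1 == 1 and batch2 == 1:
        return "11"  # plain 2D matrix multiplication
    if batch2 == 1:
        return "N1"  # only the left operand is batched, the right one is broadcast
    if batch1 == 1:
        return "1N"  # only the right operand is batched
    return "NN"  # batch matrix multiplication


for s1, s2 in [((2, 3), (3, 4)), ((5, 2, 3), (3, 4)), ((2, 3), (5, 3, 4)), ((5, 2, 3), (5, 3, 4))]:
    # numpy.matmul accepts the four layouts
    out = np.matmul(np.ones(s1), np.ones(s2))
    print(s1, s2, dot_kind_sketch(s1, s2), out.shape)
```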

to_onnx(names: List[str], opset: int | None, verbose: bool = False, **kwargs: Dict[str, Any]) Iterable[NodeProto][source]#

Converts this node into ONNX. Enumerates all ONNX nodes that participate in the conversion; the last one produces the final output.

Parameters:
  • names – dictionary in which to find the names of already converted results

  • opset – opset

  • verbose – prints out intermediate results

  • kwargs – additional parameters for the conversion

Returns:

output

GraphEinsumSubOp#

class onnx_extended.tools.einsum.einsum_impl_classes.GraphEinsumSubOp(letters: str, mat: ndarray, lengths: List[int], duplicates: List[Dict[str, int]])[source]#

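Class gathering all nodes produced to make einsum operators explicit.

Parameters:
  • letters – list of distinct letters

  • mat – matrix, see analyse_einsum_equation

  • lengths – lengths of every input

  • duplicates – see analyse_einsum_equation
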
append(op: int | EinsumSubOp) EinsumSubOp | None[source]#

Adds one input or result.

Parameters:

op – integer (an input) or an instance of EinsumSubOp.

Returns:

op or None if op is an integer

apply_sequence(*inputs: List[EinsumSubOp], verbose: bool = False, **kwargs: Dict[str, Any]) Any[source]#

Applies a sequence of operations on a list of inputs.

Parameters:
  • inputs – inputs

  • verbose – prints out intermediate results

  • kwargs – additional parameters, see apply.

Returns:

output

clean_unused_nodes(verbose: bool = False)[source]#

Removes nodes whose outputs are not used.

Parameters:

verbose – display intermediate information

mark(i: int, op: EinsumSubOp)[source]#

Marks one input or result as an intermediate result after a full einsum step.

Parameters:
mark_last_node()[source]#

Marks the last node as the final output.

remove_duplicate_transpose(verbose: bool = False)[source]#

Removes consecutive transpositions by merging them.

Parameters:

verbose – display intermediate information
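Two consecutive transpositions compose into a single permutation, which is what makes this merge possible. The helper merge_transpositions below is hypothetical and only illustrates the composition rule with numpy.transpose(); when the merged permutation is the identity, the resulting transpose does nothing.

```python
import numpy as np


def merge_transpositions(perm1, perm2):
    """Single permutation equivalent to transposing with perm1, then perm2."""
    # axis j of the final result is axis perm2[j] of the intermediate one,
    # which is axis perm1[perm2[j]] of the original array
    return tuple(perm1[j] for j in perm2)


x = np.arange(24).reshape((2, 3, 4))
p1, p2 = (0, 2, 1), (1, 0, 2)
merged = merge_transpositions(p1, p2)
print(merged)  # (2, 0, 1)
print(np.array_equal(x.transpose(p1).transpose(p2), x.transpose(merged)))  # True
```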

simplify_mm_nodes(verbose: bool = False)[source]#

Nodes whose name is suffixed by _mm are an artifact used to keep the graph consistent while building it. They can now be replaced by the equivalent node without the _mm suffix.

Parameters:

verbose – display intermediate information

to_dot(**kwargs: Dict[str, Any]) str[source]#

Produces a graph in DOT format.

Parameters:

kwargs – additional graph options

Returns:

string
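DOT is the plain-text graph language read by Graphviz. The function below is a hypothetical sketch of what such an export looks like for a chain of sub-operators; the node tuples are invented for illustration and do not reflect the labels or attributes emitted by to_dot.

```python
def to_dot_sketch(nodes):
    """Hypothetical DOT rendering of (node_id, label, input_ids) triples."""
    lines = ["digraph {"]
    for node_id, label, _ in nodes:
        lines.append(f'  {node_id} [label="{label}"];')
    for node_id, _, input_ids in nodes:
        # one edge per input, from the producer to the consumer
        for src in input_ids:
            lines.append(f"  {src} -> {node_id};")
    lines.append("}")
    return "\n".join(lines)


print(to_dot_sketch([(0, "id", []), (1, "transpose", [0]), (2, "reduce_sum", [1])]))
```

The resulting string can be rendered with the Graphviz dot command.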

to_onnx(output: str, *inputs: List[str], dtype: Any | None = None, verbose: bool = False, opset: int | None = None, **kwargs: Dict[str, Any]) ModelProto[source]#

Converts the graph into ONNX.

Parameters:
  • output – output name

  • inputs – input names

  • dtype – type used for all operators

  • opset – desired opset, None for the latest one

  • verbose – display intermediate operators

  • kwargs – additional parameters used when building the ONNX graph; supported parameters: name, ir_version, producer_name, producer_version, initializer

Returns:

ONNX graph

Not all graphs can be converted into ONNX. Only graphs produced with strategy=‘numpy’ can be converted; otherwise the following error is raised:

NotImplementedError: to_onnx not implemented for 'matmul'.

OnnxMicroRuntime#

class onnx_extended.tools.einsum.einsum_fct.OnnxMicroRuntime(model_onnx)[source]#

Implements a micro runtime for ONNX graphs. It does not implement every operator type.

Parameters:

model_onnx – ONNX model

property input_names#

Returns input names.

property output_names#

Returns output names.

run(unused: List[str] | None, inputs: Dict[str, Any]) Dict[str, Any][source]#

Computes the outputs of the graph.

Parameters:
  • unused – ignored (normally the list of desired outputs)

  • inputs – dictionary

Returns:

all intermediate results and outputs as a dictionary
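A micro runtime of this kind boils down to a table mapping operator types to numpy functions and a loop over the nodes, which ONNX stores in topological order. The sketch below is hypothetical: nodes are plain tuples (op_type, inputs, outputs, attributes) rather than onnx.NodeProto objects, only three operators are supported, and like run it returns every intermediate result.

```python
import numpy as np

# hypothetical operator table: op_type -> numpy implementation
OPS = {
    "MatMul": np.matmul,
    "Mul": np.multiply,
    "Transpose": lambda x, perm=None: np.transpose(x, perm),
}


def run_sketch(nodes, inputs):
    """Runs nodes (op_type, inputs, outputs, attributes) in the given order."""
    results = dict(inputs)  # every intermediate result is kept
    for op_type, node_inputs, node_outputs, attributes in nodes:
        if op_type not in OPS:
            raise NotImplementedError(f"operator {op_type!r} is not supported")
        value = OPS[op_type](*(results[name] for name in node_inputs), **attributes)
        results[node_outputs[0]] = value  # single-output operators only
    return results


nodes = [
    ("Transpose", ["X"], ["XT"], {"perm": (1, 0)}),
    ("MatMul", ["X", "XT"], ["Y"], {}),
]
res = run_sketch(nodes, {"X": np.arange(6).reshape((2, 3))})
print(sorted(res))  # ['X', 'XT', 'Y']
print(res["Y"])
```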

optimize_decompose_einsum_equation#

onnx_extended.tools.einsum.einsum_fct.optimize_decompose_einsum_equation(equation: str, dtype: Any, optimize: bool = False, runtime: str = 'batch_dot', cache: bool = True, opset: int | None = None, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None) CachedEinsum[source]#

Proposes a new implementation of numpy.einsum(). It does not allow expressions using an ellipsis (...) and expects an explicit right-hand side (the output after ->).

Parameters:
  • equation – einsum equation

  • dtype – numpy dtype used for the computation

  • optimize – permutes all letters to find the best permutation

  • runtime – runtime used to compute the results once the computation graph is produced (see below)

  • cache – if True, the function stores the preprocessing done for a specific equation, the second call with the same equation is much faster

  • opset – ONNX opset to use for some runtimes

  • decompose – by default, the function decomposes the equation into simpler operators, but it can also keep the original ONNX Einsum operator.

  • strategy – optimisation strategy (see below)

  • verbose – display progress if optimize is True

Returns:

an instance of CachedEinsum

The available runtimes are:

  • batch_dot: the runtime is apply_einsum_sequence,

  • python: one ONNX graph executed with a python runtime,

  • onnxruntime: one ONNX graph executed with onnxruntime.

The optimisation strategy can be:

  • None: the same runtime is used to find the best permutation of letters

  • ‘ml’: a machine-learned model is used to predict the best permutation of letters.

The function works in two steps:

  • the first step analyses the equation to produce a computation graph; this graph can also be converted into ONNX,

  • the second step runs that graph, whichever form it takes.

The function returns an object of type CachedEinsum which has the following members after optimization:

  • equation_: the best equivalent equation

  • graph_: the corresponding graph returned by function decompose_einsum_equation

  • onnx_: if a conversion to ONNX is used, stores the ONNX graph

  • runtime_: a function used by __call__, calls the runtime

  • oinf_: an object of type CReferenceEvaluator

  • timed_permutations_: memorizes the results of the optimization

<<<

import numpy
from onnx_extended.tools.einsum import optimize_decompose_einsum_equation

seq_opt = optimize_decompose_einsum_equation(
    "bsnh,btnh->bnts",
    numpy.float64,
    strategy="ml",
    verbose=1,
    runtime="python",
    optimize=True,
)

print("best equation:", seq_opt.equation_)

>>>

    4.5 mlbest='bhts,bnts->btnh': 100%|##########| 121/121 [00:00<00:00, 324.13it/s]
    best equation: bhts,bnts->btnh
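With optimize=True, the function tries equivalent equations obtained by renaming letters. The equation above has five distinct letters, so there are 5! = 120 renamings, which is consistent with the 121 iterations of the progress bar (presumably the renamings plus the original equation). The generator below is a hypothetical sketch of that enumeration; it also checks with numpy.einsum() that a renamed equation computes the same result.

```python
from itertools import permutations

import numpy as np


def letter_permutations(equation):
    """Hypothetical enumeration of every renaming of the letters of an equation."""
    letters = sorted(set(c for c in equation if c.isalpha()))
    for perm in permutations(letters):
        mapping = dict(zip(letters, perm))
        # ',', '-' and '>' are left untouched
        yield "".join(mapping.get(c, c) for c in equation)


equations = list(letter_permutations("bsnh,btnh->bnts"))
print(len(equations))  # 120
print("bhts,bnts->btnh" in equations)  # True

# a renaming changes the letters, not the computation
rng = np.random.default_rng(0)
x, y = rng.random((2, 3, 4, 5)), rng.random((2, 6, 4, 5))
print(np.allclose(np.einsum("bsnh,btnh->bnts", x, y), np.einsum("bhts,bnts->btnh", x, y)))
```

Renamed equations are mathematically identical, but the decomposition follows the letter order, so each renaming can lead to a different sequence of transpositions and matrix products.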

predict_transposition_cost#

onnx_extended.tools.einsum.einsum_ml.predict_transposition_cost(shape: Tuple[int, ...], perm: Tuple[int, ...], coefs: Dict[str, float] | None = None) float[source]#

Given a shape and a permutation, predicts the cost of the transposition.

Parameters:
  • shape – shape

  • perm – permutation

  • coefs – trained coefficients or None to get the default ones

Returns:

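predicted cost of the transposition

The following example shows the features computed by compute_transposition_features() for a given shape and permutation:

<<<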

import pprint
from onnx_extended.tools.einsum.einsum_ml import compute_transposition_features

pprint.pprint(compute_transposition_features((3, 5, 7), (2, 1, 0)))

>>>

    {'CST_': -1,
     'begin': -1,
     'dbegin': 0,
     'dend': 0,
     'dim': 3,
     'discont': 2,
     'edit': 2,
     'end': -1,
     'end16': -16,
     'end32': -32,
     'ibegin16': -0.0,
     'ibegin2': -0.0,
     'ibegin32': -0.0,
     'ibegin4': -0.0,
     'ibegin64': -0.0,
     'ibegin8': -0.0,
     'iend16': -0.0,
     'iend2': -0.0,
     'iend32': -0.0,
     'iend4': -0.0,
     'iend64': -0.0,
     'iend8': -0.0,
     'middle': 105,
     'rbegin': -0.009523809523809525,
     'rdiscont': -0.01904761904761905,
     'redit': 0.6666666666666666,
     'rend': -0.009523809523809525,
     'rend16': -0.1523809523809524,
     'rend32': -0.3047619047619048,
     'rev': 1,
     'rmiddle': 1.0,
     'rot': 0,
     'size': 105}
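Some of these features can be recomputed by hand for shape (3, 5, 7) and permutation (2, 1, 0), which helps reading the output. The definitions below are inferred from the printed values, not taken from the library, and linear_cost_sketch only illustrates a weighted sum over feature names with hypothetical coefficients, not the trained ones used by predict_transposition_cost.

```python
import math


def transposition_features_sketch(shape, perm):
    """Recomputes a subset of the features printed above (inferred definitions)."""
    dim = len(shape)
    edit = sum(1 for i, p in enumerate(perm) if i != p)  # axes that move
    # positions where the next axis is not the successor of the current one
    discont = sum(1 for a, b in zip(perm, perm[1:]) if b != a + 1)
    return {
        "dim": dim,
        "size": math.prod(shape),
        "edit": edit,
        "redit": edit / dim,
        "discont": discont,
        "rev": int(tuple(perm) == tuple(range(dim - 1, -1, -1))),
    }


def linear_cost_sketch(features, coefs):
    """Hypothetical weighted sum over the feature names present in coefs."""
    return sum(coefs[name] * value for name, value in features.items() if name in coefs)


feats = transposition_features_sketch((3, 5, 7), (2, 1, 0))
print(feats)
# hypothetical coefficients, for illustration only
print(linear_cost_sketch(feats, {"size": 0.01, "edit": 0.5}))
```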