onnx_extended.tools.einsum¶
Decomposition of Einsum into simple operations.
analyse_einsum_equation¶
- onnx_extended.tools.einsum.einsum_impl.analyse_einsum_equation(equation: str) Tuple[str, ndarray, List[int], List[Dict[str, List[int]] | None]] [source]¶
Analyses an einsum equation.
- Parameters:
equation – numpy.einsum() equation
- Returns:
four results: the list of letters, a matrix (see below), the lengths of each component, the duplicates
The returned matrix is defined as follows:
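As a minimal calling sketch (not among the original examples), the four results can be unpacked directly; the printed values are omitted here.
<<<
from onnx_extended.tools.einsum.einsum_impl import analyse_einsum_equation

# letters: string of distinct letters, mat: the matrix mentioned above,
# lengths: length of each component, duplicates: duplicated letters per term
letters, mat, lengths, duplicates = analyse_einsum_equation("bac,cd,def->ebc")
print(letters)
print(mat)
print(lengths)
print(duplicates)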
apply_einsum_sequence¶
- onnx_extended.tools.einsum.einsum_impl.apply_einsum_sequence(seq: List[ndarray], *inputs: List[EinsumSubOp], verbose: bool = False, **kwargs: Dict[str, Any]) ndarray [source]¶
Applies a sequence of operations on a list of inputs. The sequence of operations is produced by function decompose_einsum_equation().
- Parameters:
seq – sequence of operations
inputs – inputs
verbose – verbosity
kwargs – additional parameters, see apply_sequence in GraphEinsumSubOp
- Returns:
output
<<<
import numpy
from onnx_extended.tools.einsum import decompose_einsum_equation, apply_einsum_sequence

m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
m2 = numpy.arange(4).reshape((2, 2)) + 100
m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000
seq = decompose_einsum_equation("bac,cd,def->ebc")
res = apply_einsum_sequence(seq, m1, m2, m3)
print(res)
>>>
[[[ 8866198  9864696]
  [12090270 13152928]]

 [[ 8883886  9884376]
  [12114390 13179168]]]
CachedEinsum¶
- class onnx_extended.tools.einsum.einsum_fct.CachedEinsum(equation: str, runtime: str = 'batch_dot', opset: int | None = None, optimize: bool = False, dtype: ~typing.Any = <class 'numpy.float64'>, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None, key: int | None = None)[source]¶
Stores all the necessary information to cache the preprocessing of an einsum equation.
- Parameters:
equation – numpy equation
runtime – see einsum
opset – ONNX opset
optimize – finds the best letter permutation
dtype – dtype
decompose – to decompose Einsum operator or to keep it as is
key – key used to cache this class
strategy – optimization strategy
verbose – displays progress information
The class creates the following attributes:
- equation_: the best equivalent equation
- graph_: the corresponding decomposition graph
- onnx_: if a conversion to onnx is used, stores the onnx graph
- runtime_: a function used by __call__, calls the runtime
- build()[source]¶
Preprocesses the equation and builds whatever is necessary to compute the result of the einsum equation.
- static build_einsum(equation: str, runtime: str, opset: int, optimize: bool, dtype: Any, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None, key: int | None = None) CachedEinsum [source]¶
Creates an instance of CachedEinsum.
- build_onnx_einsum(input_names: List[str]) ModelProto [source]¶
Builds an ONNX graph with a single einsum operator.
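As a usage sketch for this class (not among the original examples): build_einsum creates an instance, which is then assumed to be callable through runtime_ as described above; the output is omitted here.
<<<
import numpy
from onnx_extended.tools.einsum.einsum_fct import CachedEinsum

# build_einsum is assumed to preprocess the equation (via build)
inst = CachedEinsum.build_einsum(
    "ab,bc->ac", runtime="batch_dot", opset=None, optimize=False, dtype=numpy.float64
)
m1 = numpy.random.randn(2, 3)
m2 = numpy.random.randn(3, 4)
# __call__ relies on the runtime_ attribute created during the build
print(inst(m1, m2))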
compute_transposition_features¶
- onnx_extended.tools.einsum.einsum_ml.compute_transposition_features(shape: Tuple[int, ...], perm: Tuple[int, ...]) Dict[str, float] [source]¶
Given a shape and a permutation, computes many features used to predict the cost of the transposition.
- Parameters:
shape – shape
perm – permutation
- Returns:
dictionary of features
<<<
import pprint
from onnx_extended.tools.einsum.einsum_ml import compute_transposition_features

pprint.pprint(compute_transposition_features((3, 5, 7), (2, 1, 0)))
>>>
{'CST_': -1,
 'begin': -1,
 'dbegin': 0,
 'dend': 0,
 'dim': 3,
 'discont': 2,
 'edit': 2,
 'end': -1,
 'end16': -16,
 'end32': -32,
 'ibegin16': -0.0,
 'ibegin2': -0.0,
 'ibegin32': -0.0,
 'ibegin4': -0.0,
 'ibegin64': -0.0,
 'ibegin8': -0.0,
 'iend16': -0.0,
 'iend2': -0.0,
 'iend32': -0.0,
 'iend4': -0.0,
 'iend64': -0.0,
 'iend8': -0.0,
 'middle': 105,
 'rbegin': -0.009523809523809525,
 'rdiscont': -0.01904761904761905,
 'redit': 0.6666666666666666,
 'rend': -0.009523809523809525,
 'rend16': -0.1523809523809524,
 'rend32': -0.3047619047619048,
 'rev': 1,
 'rmiddle': 1.0,
 'rot': 0,
 'size': 105}
decompose_einsum_equation¶
- onnx_extended.tools.einsum.einsum_impl.decompose_einsum_equation(equation: str, *shapes: List[Tuple[int, ...]], strategy: str = 'simple', clean: bool = False, verbose: bool = False) GraphEinsumSubOp [source]¶
Decomposes an equation used in numpy.einsum() knowing the input shapes. It returns the sequence of operations needed to compute the result.
- Parameters:
equation – a string
shapes – sequence of input shapes
strategy – there are different ways to decompose the equation; this parameter defines which one is used (see below)
clean – removes unnecessary nodes from the graph
verbose – verbosity
- Returns:
instance of GraphEinsumSubOp
About strategy:
- ‘simple’: aligns all dimensions in alphabetical order; some generic matrix multiplications remain implemented with numpy.einsum() but only with two matrices aligned on the same dimension (see numpy_extended_dot),
- ‘numpy’: same as ‘simple’ but the decomposition does not use numpy.einsum() anymore, only multiplications or matrix multiplications merged into a single operator called batch_dot (see numpy_extended_dot_matrix).
Available operations: expand_dims, transpose, matmul, reduce_sum, id, squeeze, diagonal. The function analyses an equation and produces a graph whose nodes are instances of class EinsumSubOp.
<<<
from onnx_extended.tools.einsum import decompose_einsum_equation

seq = decompose_einsum_equation("bac,cd,def->ebc")
for op in seq:
    print(op)
>>>
EinsumSubOp('id', 0, )
EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5)))
EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5))
EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,))
EinsumSubOp('id', 1, )
EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5)))
EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6)
EinsumSubOp('id', 2, )
EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2)))
EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,))
EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6)
EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(0, 4, 1, 3, 2, 5))
EinsumSubOp('squeeze', EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(1, 0, 2, 3, 4, 5)), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(0, 4, 1, 3, 2, 5)), axes=(0, 3, 5))
The decomposition is easier to read when displayed as a graph.
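A short sketch (not among the original examples) produces the DOT string with method to_dot, documented below for GraphEinsumSubOp; the rendered graph is not reproduced here.
<<<
from onnx_extended.tools.einsum import decompose_einsum_equation

seq = decompose_einsum_equation("bac,cd,def->ebc")
# to_dot returns a string in DOT format which graphviz can render
print(seq.to_dot())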
einsum¶
- onnx_extended.tools.einsum.einsum_fct.einsum(equation: str, *inputs: List[ndarray], optimize: bool = False, runtime: str = 'batch_dot', cache: bool = True, opset: int | None = None, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None) ndarray [source]¶
Proposes a new implementation of numpy.einsum(). It does not allow expressions using … and expects a right member.
- Parameters:
equation – einsum equation
inputs – inputs
optimize – permutes all letters to find the best permutation
runtime – runtime used to compute the results once the computation graph is produced (see below)
cache – if True, the function stores the preprocessing done for a specific equation, the second call with the same equation is much faster
opset – ONNX opset to use for some runtimes
decompose – by default, the function decomposes the equation into simpler operators; if False, it keeps the original ONNX Einsum operator.
strategy – optimisation strategy (see below)
verbose – display progress if optimize is True
- Returns:
einsum result
The available runtimes are:
- batch_dot: the runtime is apply_einsum_sequence,
- python: one ONNX graph executed with a python runtime,
- onnxruntime: one ONNX graph executed with onnxruntime.
The optimisation strategy can be:
- None: the same runtime is used to find the best permutation of letters,
- ‘ml’: a machine learned model is used to predict the best permutation of letters.
The function works in two steps:
- the first step analyses the equation to produce a computation graph, this graph can also be converted into ONNX,
- the second step runs the graph whatever the graph is.
Further details are available in the documentation of function optimize_decompose_einsum_equation(). The function works the same way as numpy.einsum():
<<<
import numpy
from onnx_extended.tools.einsum import einsum

equation = "abc,cd->abd"
m1 = numpy.random.randn(2, 2, 2)
m2 = numpy.random.randn(2, 2)
np = numpy.einsum(equation, m1, m2)
print("numpy.einsum")
print(np)
print("onnx_extended.tools.einsum")
mp = einsum(equation, m1, m2)
print(mp)
>>>
numpy.einsum
[[[-2.306  2.062]
  [ 0.629 -0.82 ]]

 [[-1.915  1.989]
  [-2.625  2.313]]]
onnx_extended.tools.einsum
[[[-2.306  2.062]
  [ 0.629 -0.82 ]]

 [[-1.915  1.989]
  [-2.625  2.313]]]
In some cases, the einsum implementation can be optimized by looping over the possible letter permutations:
<<<
import timeit
import numpy
from onnx_extended.tools.einsum import einsum
from onnx_extended.tools.einsum.einsum_fct import enumerate_cached_einsum

equation = "cab,cd->ad"
m1 = numpy.random.randn(20, 20, 20)
m2 = numpy.random.randn(20, 20)

print(
    "numpy.einsum",
    timeit.timeit("numpy.einsum(equation, m1, m2)", number=200, globals=globals()),
)

einsum(equation, m1, m2)
print(
    "einsum", timeit.timeit("einsum(equation, m1, m2)", number=200, globals=globals())
)

einsum(equation, m1, m2, runtime="python")
print(
    "einsum-python",
    timeit.timeit(
        'einsum(equation, m1, m2, runtime="python")', number=200, globals=globals()
    ),
)

einsum(equation, m1, m2, runtime="onnxruntime")
print(
    "einsum-onnxruntime",
    timeit.timeit(
        'einsum(equation, m1, m2, runtime="onnxruntime")', number=200, globals=globals()
    ),
)

einsum(equation, m1, m2, runtime="onnxruntime", optimize=True, verbose=1)
print(
    "einsum-onnxruntime",
    timeit.timeit(
        'einsum(equation, m1, m2, runtime="onnxruntime", optimize=True)',
        number=200,
        globals=globals(),
    ),
)

print("list of cached einsum equations")
for k, v in enumerate_cached_einsum():
    print(k, v.equation, v.equation_)
>>>
numpy.einsum 0.038955299998633564
einsum 0.020066400000359863
einsum-python 0.09573180000006687
einsum-onnxruntime 0.013857800004188903
einsum-onnxruntime 0.013168299999961164
list of cached einsum equations
('cab,cd->ad', 'batch_dot', None, False, dtype('float64'), True, None) cab,cd->ad cab,cd->ad
('cab,cd->ad', 'python', None, False, dtype('float64'), True, None) cab,cd->ad cab,cd->ad
('cab,cd->ad', 'onnxruntime', None, False, dtype('float64'), True, None) cab,cd->ad cab,cd->ad
('cab,cd->ad', 'onnxruntime', None, True, dtype('float64'), True, None) cab,cd->ad dba,dc->bc
0.00043 rtbest='dba,dc->bc': 100%|██████████| 25/25 [00:00<00:00, 91.43it/s]
The last example shows the time taken by every function:
<<<
import logging
import os
import cProfile
from io import StringIO
from pstats import Stats
import numpy
from onnx_extended.tools.einsum import einsum
from onnx_extended.tools.einsum.einsum_fct import enumerate_cached_einsum
from onnx_extended import __file__ as path


def profile(fct, sort="cumulative", **kwargs):
    pr = cProfile.Profile(**kwargs)
    pr.enable()
    fct_res = fct()
    pr.disable()
    s = StringIO()
    ps = Stats(pr, stream=s).sort_stats(sort)
    ps.print_stats()
    res = s.getvalue()
    return ps, res


root = os.path.dirname(path)
logging.getLogger("onnx-extended").setLevel(logging.ERROR)

equation = "cab,cd->ad"
m1 = numpy.random.randn(200, 20, 20)
m2 = numpy.random.randn(200, 20)


def clean(txt):
    txt = txt.replace(root, "onnx_extended")
    return "\n".join(txt.split("\n")[:30])


def fct1():
    for i in range(100):
        einsum(equation, m1, m2, cache=False)


print("Profile cache with default runtime.")
res = profile(fct1)
print(root)
print(clean(res[1]))


def fct2():
    for i in range(100):
        einsum(equation, m1, m2, cache=False, runtime="python")


print("Profile cache with runtime='python'.")
res = profile(fct2)
print(root)
print(clean(res[1]))


def fct3():
    for i in range(100):
        einsum(equation, m1, m2, cache=True)


einsum(equation, m1, m2, cache=True)
print("Profile execution with default runtime.")
res = profile(fct3)
print(root)
print(clean(res[1]))


def fct4():
    for i in range(100):
        einsum(equation, m1, m2, cache=True, runtime="python")


einsum(equation, m1, m2, cache=True, runtime="python")
print("Profile execution with runtime='python'.")
res = profile(fct4)
print(root)
print(clean(res[1]))


def fct5():
    for i in range(100):
        einsum(equation, m1, m2, cache=True, runtime="onnxruntime")


einsum(equation, m1, m2, cache=True, runtime="onnxruntime")
print("Profile execution with runtime='onnxruntime'.")
res = profile(fct5)
print(root)
print(clean(res[1]))
>>>
Profile cache with default runtime.
/home/xadupre/github/onnx-extended/onnx_extended
         128702 function calls (128502 primitive calls) in 0.465 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    0.465    0.465 <stdin>:45(fct1)
      100    0.002    0.000    0.464    0.005 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
      100    0.001    0.000    0.328    0.003 onnx_extended/tools/einsum/einsum_fct.py:599(optimize_decompose_einsum_equation)
      100    0.001    0.000    0.328    0.003 onnx_extended/tools/einsum/einsum_fct.py:562(_einsum)
      100    0.001    0.000    0.327    0.003 onnx_extended/tools/einsum/einsum_fct.py:532(build_einsum)
      100    0.001    0.000    0.326    0.003 onnx_extended/tools/einsum/einsum_fct.py:305(build)
      100    0.001    0.000    0.325    0.003 onnx_extended/tools/einsum/einsum_fct.py:475(build_runtime)
      100    0.003    0.000    0.324    0.003 onnx_extended/tools/einsum/einsum_impl.py:79(decompose_einsum_equation)
      100    0.041    0.000    0.246    0.002 onnx_extended/tools/einsum/einsum_impl.py:415(_decompose_einsum_equation_simple)
      100    0.000    0.000    0.134    0.001 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
      100    0.001    0.000    0.133    0.001 onnx_extended/tools/einsum/einsum_fct.py:500(<lambda>)
      100    0.001    0.000    0.132    0.001 onnx_extended/tools/einsum/einsum_impl.py:163(apply_einsum_sequence)
      100    0.005    0.000    0.132    0.001 onnx_extended/tools/einsum/einsum_impl_classes.py:1390(apply_sequence)
     1200    0.007    0.000    0.125    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:717(apply)
     1200    0.015    0.000    0.088    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:412(compute_output_row)
      100    0.001    0.000    0.048    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:1415(clean_unused_nodes)
      200    0.028    0.000    0.047    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:1422(iteration)
     1900    0.039    0.000    0.039    0.000 {method 'reduce' of 'numpy.ufunc' objects}
      100    0.008    0.000    0.039    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:596(_apply_batch_dot)
      500    0.004    0.000    0.037    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:71(_wrapreduction)
     4800    0.010    0.000    0.035    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:19(single_axes)
      100    0.001    0.000    0.030    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:676(_apply_reduce_sum)
      500    0.020    0.000    0.030    0.000 onnx_extended/tools/einsum/einsum_impl.py:232(_apply_transpose_reshape)
      100    0.001    0.000    0.027    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2177(sum)
     3800    0.025    0.000    0.025    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:32(<listcomp>)

Profile cache with runtime='python'.
/home/xadupre/github/onnx-extended/onnx_extended
         437841 function calls (434145 primitive calls) in 2.725 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    2.725    2.725 <stdin>:56(fct2)
      100    0.014    0.000    2.724    0.027 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
      100    0.001    0.000    1.815    0.018 onnx_extended/tools/einsum/einsum_fct.py:599(optimize_decompose_einsum_equation)
      100    0.001    0.000    1.815    0.018 onnx_extended/tools/einsum/einsum_fct.py:562(_einsum)
      100    0.002    0.000    1.814    0.018 onnx_extended/tools/einsum/einsum_fct.py:532(build_einsum)
      100    0.001    0.000    1.811    0.018 onnx_extended/tools/einsum/einsum_fct.py:305(build)
      100    0.013    0.000    1.810    0.018 onnx_extended/tools/einsum/einsum_fct.py:475(build_runtime)
      100    0.010    0.000    0.894    0.009 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
      100    0.003    0.000    0.884    0.009 onnx_extended/tools/einsum/einsum_fct.py:512(<lambda>)
      100    0.127    0.001    0.880    0.009 onnx_extended/reference/c_reference_evaluator.py:262(run)
      100    0.038    0.000    0.649    0.006 onnx_extended/tools/einsum/einsum_impl_classes.py:1652(to_onnx)
      100    0.014    0.000    0.629    0.006 onnx_extended/reference/c_reference_evaluator.py:231(__init__)
      100    0.027    0.000    0.532    0.005 /home/xadupre/github/onnx/onnx/reference/reference_evaluator.py:203(__init__)
     2300    0.091    0.000    0.519    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:443(run)
      100    0.004    0.000    0.505    0.005 onnx_extended/tools/einsum/einsum_impl.py:79(decompose_einsum_equation)
      100    0.032    0.000    0.468    0.005 /home/xadupre/github/onnx/onnx/reference/reference_evaluator.py:399(_init)
      100    0.050    0.001    0.392    0.004 onnx_extended/tools/einsum/einsum_impl.py:415(_decompose_einsum_equation_simple)
     5000    0.038    0.000    0.363    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:1144(to_onnx)
     2800    0.038    0.000    0.262    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:215(__init__)
     3400    0.045    0.000    0.218    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:411(_check_and_fix_outputs)
     2800    0.104    0.000    0.215    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:286(_load_attributes)
  200/100    0.106    0.001    0.173    0.002 onnx_extended/tools/onnx_nodes.py:112(onnx_remove_node_unused)
      100    0.006    0.000    0.166    0.002 onnx_extended/tools/onnx_nodes.py:60(_apply_optimisation_on_graph)
     1200    0.023    0.000    0.150    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:412(compute_output_row)
     2800    0.037    0.000    0.134    0.000 /home/xadupre/github/onnx/onnx/helper.py:132(make_node)

Profile execution with default runtime.
/home/xadupre/github/onnx-extended/onnx_extended
         32202 function calls in 0.411 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    0.411    0.411 <stdin>:67(fct3)
      100    0.003    0.000    0.410    0.004 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
      100    0.001    0.000    0.403    0.004 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
      100    0.001    0.000    0.400    0.004 onnx_extended/tools/einsum/einsum_fct.py:500(<lambda>)
      100    0.002    0.000    0.399    0.004 onnx_extended/tools/einsum/einsum_impl.py:163(apply_einsum_sequence)
      100    0.010    0.000    0.398    0.004 onnx_extended/tools/einsum/einsum_impl_classes.py:1390(apply_sequence)
     1200    0.030    0.000    0.384    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:717(apply)
      100    0.033    0.000    0.133    0.001 onnx_extended/tools/einsum/einsum_impl_classes.py:596(_apply_batch_dot)
      500    0.029    0.000    0.131    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:71(_wrapreduction)
      100    0.011    0.000    0.102    0.001 onnx_extended/tools/einsum/einsum_impl_classes.py:676(_apply_reduce_sum)
      500    0.094    0.000    0.094    0.000 {method 'reduce' of 'numpy.ufunc' objects}
      100    0.002    0.000    0.087    0.001 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2177(sum)
      400    0.009    0.000    0.055    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2979(prod)
      200    0.006    0.000    0.044    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:503(_apply_expand_dims)
      400    0.010    0.000    0.043    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:516(_apply_transpose)
      300    0.009    0.000    0.034    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/lib/shape_base.py:512(expand_dims)
      100    0.020    0.000    0.020    0.000 onnx_extended/tools/einsum/blas_lapack.py:76(gemm_dot)
     2000    0.019    0.000    0.019    0.000 {built-in method builtins.getattr}
     1300    0.013    0.000    0.019    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:463(_get_data)
      400    0.002    0.000    0.018    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:588(transpose)
      300    0.014    0.000    0.017    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/numeric.py:1330(normalize_axis_tuple)
      100    0.003    0.000    0.016    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:703(_apply_squeeze)
      400    0.002    0.000    0.016    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:53(_wrapfunc)
      100    0.010    0.000    0.012    0.000 onnx_extended/tools/einsum/einsum_impl_classes.py:1193(get_dot_kind)
      200    0.010    0.000    0.011    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:1491(squeeze)

Profile execution with runtime='python'.
/home/xadupre/github/onnx-extended/onnx_extended
         145502 function calls in 1.575 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    1.575    1.575 <stdin>:79(fct4)
      100    0.003    0.000    1.573    0.016 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
      100    0.001    0.000    1.568    0.016 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
      100    0.005    0.000    1.567    0.016 onnx_extended/tools/einsum/einsum_fct.py:512(<lambda>)
      100    0.224    0.002    1.562    0.016 onnx_extended/reference/c_reference_evaluator.py:262(run)
     2300    0.169    0.000    0.922    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:443(run)
     3400    0.090    0.000    0.366    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:411(_check_and_fix_outputs)
      400    0.007    0.000    0.213    0.001 /home/xadupre/github/onnx/onnx/reference/ops/_op.py:43(run)
      400    0.007    0.000    0.153    0.000 /home/xadupre/github/onnx/onnx/reference/ops/_op.py:20(run)
     6800    0.020    0.000    0.145    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:428(<genexpr>)
     6800    0.040    0.000    0.126    0.000 {built-in method builtins.any}
     3400    0.053    0.000    0.125    0.000 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/numeric.py:1855(isscalar)
      100    0.002    0.000    0.112    0.001 /home/xadupre/github/onnx/onnx/reference/ops/op_reduce_sum.py:22(_run)
    20900    0.055    0.000    0.109    0.000 {built-in method builtins.isinstance}
     5600    0.045    0.000    0.107    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:244(_log)
      100    0.002    0.000    0.087    0.001 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2177(sum)
      100    0.002    0.000    0.086    0.001 /home/xadupre/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:71(_wrapreduction)
      300    0.002    0.000    0.079    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_reshape.py:34(_run)
      300    0.041    0.000    0.077    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_reshape.py:11(reshape_reference_implementation)
      100    0.074    0.001    0.074    0.001 {method 'reduce' of 'numpy.ufunc' objects}
      400    0.002    0.000    0.074    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_identity.py:10(_run)
      400    0.072    0.000    0.072    0.000 {method 'copy' of 'numpy.ndarray' objects}
     5600    0.042    0.000    0.061    0.000 /home/xadupre/github/onnx/onnx/reference/reference_evaluator.py:410(<lambda>)
      200    0.009    0.000    0.059    0.000 /home/xadupre/github/onnx/onnx/reference/ops/op_unsqueeze.py:35(_run)
     3400    0.011    0.000    0.053    0.000 /usr/lib/python3.10/abc.py:117(__instancecheck__)

Profile execution with runtime='onnxruntime'.
/home/xadupre/github/onnx-extended/onnx_extended
         1602 function calls in 0.229 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.001    0.001    0.229    0.229 <stdin>:91(fct5)
      100    0.002    0.000    0.228    0.002 onnx_extended/tools/einsum/einsum_fct.py:691(einsum)
      100    0.000    0.000    0.224    0.002 onnx_extended/tools/einsum/einsum_fct.py:525(__call__)
      100    0.001    0.000    0.223    0.002 onnx_extended/tools/einsum/einsum_fct.py:512(<lambda>)
      100    0.221    0.002    0.222    0.002 /home/xadupre/github/onnxruntime/build/linux_cuda/Release/onnxruntime/capi/onnxruntime_inference_collection.py:202(run)
      100    0.000    0.000    0.001    0.000 onnx_extended/tools/einsum/einsum_fct.py:599(optimize_decompose_einsum_equation)
      100    0.001    0.000    0.001    0.000 /home/xadupre/github/onnxruntime/build/linux_cuda/Release/onnxruntime/capi/onnxruntime_inference_collection.py:192(_validate_input)
      100    0.000    0.000    0.001    0.000 onnx_extended/tools/einsum/einsum_fct.py:562(_einsum)
      300    0.001    0.000    0.001    0.000 onnx_extended/tools/einsum/einsum_fct.py:931(<genexpr>)
      100    0.000    0.000    0.000    0.000 onnx_extended/tools/einsum/einsum_fct.py:513(<dictcomp>)
      100    0.000    0.000    0.000    0.000 /home/xadupre/github/onnxruntime/build/linux_cuda/Release/onnxruntime/capi/onnxruntime_inference_collection.py:218(<listcomp>)
      100    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
      100    0.000    0.000    0.000    0.000 {built-in method builtins.len}
      100    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
      100    0.000    0.000    0.000    0.000 {method 'keys' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
einsum_benchmark¶
- onnx_extended.tools.einsum.einsum_bench.einsum_benchmark(equation: str = 'abc,cd->abd', shape: int = 30, perm: bool = False, runtime: str = 'python', use_tqdm: bool = False, number: int = 5, repeat: int = 5, opset=18) Iterable[Dict[str, str | float]] [source]¶
Investigates whether or not decomposing einsum is faster.
- Parameters:
equation – einsum equation to test
shape – an integer (all dimensions get the same size) or a list of shapes given as a string separated with ';'
perm – if True, checks all letter permutations
runtime – a string among ‘numpy’, ‘python’, ‘onnxruntime’
use_tqdm – show progress
number – usual parameter to measure a function (see timeit)
repeat – usual parameter to measure a function (see timeit)
opset – target opset
- Returns:
list of dictionaries as an iterator
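A minimal calling sketch (not among the original examples), using the defaults documented above; each yielded dictionary reports one measurement and the output is omitted here.
<<<
from onnx_extended.tools.einsum.einsum_bench import einsum_benchmark

# iterate over the benchmark results for a small shape
for res in einsum_benchmark(equation="abc,cd->abd", shape=8, runtime="python"):
    print(res)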
numpy_extended_dot¶
- onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot(m1: ndarray, m2: ndarray, axes: Tuple[int, ...], left: Tuple[int, ...], right: Tuple[int, ...], verbose: bool = False) ndarray [source]¶
Extended version of a matrix multiplication (numpy.dot()) with two matrices m1, m2 of the same dimensions. Loops over left axes for m1 and right axes for m2, summation is done over axes. Other axes must be empty. This multiplication combines matrix multiplication (dot) and broadcasted multiplication term by term.
- Parameters:
m1 – first matrix
m2 – second matrix
axes – summation axes
left – left axes
right – right axes
verbose – display intermediate information
- Returns:
output
The dot product is equivalent to:
<<<
import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(4).reshape((2, 2))
m2 = m1 + 10
print("dot product")
print(m1 @ m2)

dm1 = m1.reshape((2, 2, 1))
dm2 = m2.reshape((1, 2, 2))
dot = numpy_extended_dot(dm1, dm2, axes=[1], left=[0], right=[2], verbose=True)
print("extended dot product")
print(dot)
>>>
dot product
[[12 13]
 [56 61]]
[numpy_extended_dot] Abc,abC->AC: (2, 2, 1) @ (1, 2, 2)
[numpy_extended_dot] (2, 2) reshaped into [2, 1, 2]
extended dot product
[[[12 13]]

 [[56 61]]]
Empty axes should be squeezed to get identical results. The next example shows the dot product when the second matrix is transposed.
<<<
import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(4).reshape((2, 2))
m2 = m1 + 10
print("dot product")
print(m1 @ m2.T)

dm1 = m1.reshape((2, 1, 2))
dm2 = m2.reshape((1, 2, 2))
dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1], verbose=True)
print("extended dot product")
print(dot)
>>>
dot product
[[11 13]
 [53 63]]
[numpy_extended_dot] Abc,aBc->AB: (2, 1, 2) @ (1, 2, 2)
[numpy_extended_dot] (2, 2) reshaped into [2, 2, 1]
extended dot product
[[[11]
  [13]]

 [[53]
  [63]]]
An example when right axes include the summation axis.
<<<
import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(4).reshape((2, 2))
m2 = m1 + 10
dm1 = m1.reshape((2, 2, 1))
dm2 = m2.reshape((1, 2, 2))
dot = numpy_extended_dot(dm1, dm2, axes=[2], left=[0], right=[1, 2], verbose=True)
print(dot)
>>>
[numpy_extended_dot] Abc,aBc->ABc: (2, 2, 1) @ (1, 2, 2)
[numpy_extended_dot] (2, 2, 2) reshaped into [2, 2, 2]
[[[10 11]
  [12 13]]

 [[50 55]
  [60 65]]]
Example in higher dimension:
<<<
import numpy
from onnx_extended.tools.einsum import numpy_extended_dot

m1 = numpy.arange(8).reshape((2, 2, 2))
m2 = m1 + 10
dot = numpy_extended_dot(m1, m2, [1], [0], [2], verbose=True)
print(dot)
>>>
[numpy_extended_dot] Abc,abC->AC: (2, 2, 2) @ (2, 2, 2)
[numpy_extended_dot] (2, 2) reshaped into [2, 1, 2]
[[[164 176]]

 [[580 624]]]
The current implementation still uses numpy.einsum() but this should be replaced.
numpy_extended_dot_matrix¶
- onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_matrix(m1: ndarray, m2: ndarray, axes: Tuple[int, ...], left: Tuple[int, ...], right: Tuple[int, ...], verbose: bool = False) ndarray [source]¶
Implementation of numpy_extended_dot using dot product, multiplication, transpose and reduction but not a custom python implementation like numpy_extended_dot_python.
<<<
import numpy
from onnx_extended.tools.einsum import numpy_extended_dot_matrix
from onnx_extended.tools.einsum.einsum_impl_ext import _numpy_extended_dot_equation

a = numpy.arange(6).reshape((3, 2, 1))
b = numpy.arange(12).reshape((3, 1, 4))

print(numpy_extended_dot_matrix(a, b, axes=(0,), left=(1,), right=(2,)))

# Equivalent einsum equation
print(
    "equation",
    _numpy_extended_dot_equation(
        len(a.shape), len(a.shape), axes=(0,), left=(1,), right=(2,)
    ),
)

# Same einsum computation written in a different way.
print(numpy.einsum("kix,kxj->xij", a, b))
>>>
[[[40 46 52 58]
  [52 61 70 79]]]
equation aBc,abC->BC
[[[40 46 52 58]
  [52 61 70 79]]]
numpy_extended_dot_python¶
- onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_python(m1: ndarray, m2: ndarray, axes: Tuple[int, ...], left: Tuple[int, ...], right: Tuple[int, ...], verbose: bool = False) ndarray [source]¶
Implementation of numpy_extended_dot in pure python. This implementation is not efficient but shows how to implement this operation without numpy.einsum().
<<<
import numpy
from onnx_extended.tools.einsum import numpy_extended_dot_python
from onnx_extended.tools.einsum.einsum_impl_ext import _numpy_extended_dot_equation

a = numpy.arange(6).reshape((3, 2, 1))
b = numpy.arange(12).reshape((3, 1, 4))

print(numpy_extended_dot_python(a, b, axes=(0,), left=(1,), right=(2,)))

# Equivalent einsum equation
print(
    "equation",
    _numpy_extended_dot_equation(
        len(a.shape), len(a.shape), axes=(0,), left=(1,), right=(2,)
    ),
)

# Same einsum computation written in a different way.
print(numpy.einsum("kix,kxj->xij", a, b))
>>>
[[[40 46 52 58]
  [52 61 70 79]]]
equation aBc,abC->BC
[[[40 46 52 58]
  [52 61 70 79]]]
EinsumSubOp¶
- class onnx_extended.tools.einsum.einsum_impl_classes.EinsumSubOp(full_dim: int, name: str, *inputs: List[EinsumSubOp], **kwargs: Dict[str, Any])[source]¶
Defines a sub operation used in Einsum decomposition.
- Parameters:
full_dim – dimension of the result
name – name (reshape, transpose, reduce_sum, matmul, id, squeeze, diagonal, mul, batch_dot)
inputs – inputs
kwargs – arguments
Operators suffixed by _mm (transpose_mm, reduce_sum_mm) are equivalent to the same operators without the suffix but take two inputs and only change the first one.
Attribute _info summarizes the known information about dimensions. Many dimensions are empty because they were inserted; value 1 means the dimension was inserted, value 2 means it is a plain dimension.
- add_info(**kwargs: Dict[str, Any])[source]¶
Adds information to the node.
- Parameters:
kwargs – dictionary
- apply(data: Dict[int, Any], verbose: bool = False, **kwargs: Dict[str, Any]) Any [source]¶
Applies one operator on the data.
- Parameters:
data – dictionary storing the results
verbose – prints out intermediate results
kwargs – additional parameters, see methods _apply*
- Returns:
output
Known additional parameters:
- ‘matmul_impl’: if None, calls numpy.einsum() through numpy_extended_dot (default); ‘py’ calls numpy_extended_dot_python instead.
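As a sketch (not among the original examples), assuming matmul_impl is forwarded from apply_einsum_sequence down to every EinsumSubOp.apply as the parameter descriptions above indicate, the pure python implementation can be selected as follows; output omitted.
<<<
import numpy
from onnx_extended.tools.einsum import decompose_einsum_equation, apply_einsum_sequence

m1 = numpy.arange(4).reshape((2, 2)).astype(numpy.float64)
m2 = m1 + 10
seq = decompose_einsum_equation("ab,bc->ac")
# matmul_impl='py' is assumed to route matmul nodes through numpy_extended_dot_python
print(apply_einsum_sequence(seq, m1, m2, matmul_impl="py"))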
- compute_output_row(row: ndarray, row2: ndarray | None = None, ab: bool = False, verbose: bool = False)[source]¶
Updates row based on the operator.
- get_dot_kind() str [source]¶
Every matrix multiplication can be either:
- a simple multiplication (M) (undetected),
- a 2D matrix multiplication (11),
- a broadcasted matrix multiplication (N1 or 1N),
- a batch matrix multiplication (NN).
This method returns which kind it is.
- to_onnx(names: List[str], opset: int | None, verbose: bool = False, **kwargs: Dict[str, Any]) Iterable[NodeProto] [source]¶
Converts this node into ONNX. Enumerates all ONNX nodes which participate in the conversion. The last one is the final output.
- Parameters:
names – dictionary where to find already converted names
opset – opset
verbose – prints out intermediate results
kwargs – additional parameter for the conversion
- Returns:
output
GraphEinsumSubOp¶
- class onnx_extended.tools.einsum.einsum_impl_classes.GraphEinsumSubOp(letters: str, mat: ndarray, lengths: List[int], duplicates: List[Dict[str, int]])[source]¶
Class gathering all nodes produced to make einsum operators explicit.
- Parameters:
letters – list of distinct letters
mat – matrix, see analyse_einsum_equation
lengths – lengths of every input
duplicates – see analyse_einsum_equation
- append(op: int | EinsumSubOp) EinsumSubOp | None [source]¶
Adds one input or result.
- Parameters:
op – integer (an input) or an instance of EinsumSubOp
- Returns:
op, or None if op is an integer
- apply_sequence(*inputs: List[EinsumSubOp], verbose: bool = False, **kwargs: Dict[str, Any]) Any [source]¶
Applies a sequence of operations on a list of inputs.
- Parameters:
inputs – inputs
verbose – prints out intermediate results
kwargs – additional parameters, see apply
- Returns:
output
- clean_unused_nodes(verbose: bool = False)[source]¶
Cleans nodes with unused outputs.
- Parameters:
verbose – display intermediate information
- mark(i: int, op: EinsumSubOp)[source]¶
Marks one input or result as an intermediate result after a full einsum step.
- Parameters:
i – a position
op – an instance of EinsumSubOp
- remove_duplicate_transpose(verbose: bool = False)[source]¶
Removes consecutive transposes by merging them.
- Parameters:
verbose – display intermediate information
- simplify_mm_nodes(verbose: bool = False)[source]¶
Nodes whose names are suffixed by mm are an artifact to keep the graph consistent while building it. They can now be replaced by the equivalent nodes without the suffix.
- Parameters:
verbose – display intermediate information
- to_dot(**kwargs: Dict[str, Any]) str [source]¶
Produces a graph in dot.
- Parameters:
kwargs – additional graph option
- Returns:
string
- to_onnx(output: str, *inputs: List[str], dtype: Any | None = None, verbose: bool = False, opset: int | None = None, **kwargs: Dict[str, Any]) ModelProto [source]¶
Converts the graph into ONNX.
- Parameters:
output – output name
inputs – input names
dtype – type used for all operators
opset – desired opset, None for the last one
verbose – display intermediate operators
kwargs – additional parameters to use when building the ONNX graph; supported parameters: name, ir_version, producer_name, producer_version, initializer
- Returns:
ONNX graph
Not all graphs can be converted into ONNX. Only graphs produced with strategy=’numpy’ can be converted; otherwise the following error shows up:
NotImplementedError: to_onnx not implemented for 'matmul'.
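A sketch of a conversion that avoids this error (not among the original examples), following the note above; the names "Y", "X1" and "X2" are arbitrary choices for this example and the output is omitted.
<<<
import numpy
from onnx_extended.tools.einsum import decompose_einsum_equation

# strategy='numpy' avoids generic matmul nodes that to_onnx cannot convert
seq = decompose_einsum_equation("ab,bc->ac", strategy="numpy", clean=True)
onx = seq.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
print(type(onx))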
optimize_decompose_einsum_equation¶
- onnx_extended.tools.einsum.einsum_fct.optimize_decompose_einsum_equation(equation: str, dtype: Any, optimize: bool = False, runtime: str = 'batch_dot', cache: bool = True, opset: int | None = None, decompose: bool = True, strategy: str | None = None, verbose: bool | None = None) CachedEinsum [source]¶
Proposes a new implementation of numpy.einsum(). It does not allow expressions using … and expects a right member.
- Parameters:
equation – einsum equation
dtype – numpy dtype used for the computation
optimize – permutes all letters to find the best permutation
runtime – runtime used to compute the results once the computation graph is produced (see below)
cache – if True, the function stores the preprocessing done for a specific equation, the second call with the same equation is much faster
opset – ONNX opset to use for some runtimes
decompose – by default, the function decomposes the equation into simpler operators; if False, it keeps the original ONNX Einsum operator.
strategy – optimisation strategy (see below)
verbose – display progress if optimize is True
- Returns:
einsum result
The available runtimes are:
- batch_dot: the runtime is apply_einsum_sequence,
- python: one ONNX graph executed with a python runtime,
- onnxruntime: one ONNX graph executed with onnxruntime.
The optimisation strategy can be:
- None: the same runtime is used to find the best permutation of letters,
- ‘ml’: a machine learned model is used to predict the best permutation of letters.
The function works in two steps:
- the first step analyses the equation to produce a computation graph, this graph can also be converted into ONNX,
- the second step runs the graph whatever the graph is.
The function returns an object of type CachedEinsum which has the following members after optimization:
- equation_: the best equivalent equation
- graph_: the corresponding decomposition graph
- onnx_: if a conversion to onnx is used, stores the onnx graph
- runtime_: a function used by __call__, calls the runtime
- oinf_: an object of type CReferenceEvaluator
- timed_permutations_: memorizes the results of the optimization
<<<
import numpy
from onnx_extended.tools.einsum import optimize_decompose_einsum_equation

seq_opt = optimize_decompose_einsum_equation(
    "bsnh,btnh->bnts",
    numpy.float64,
    strategy="ml",
    verbose=1,
    runtime="python",
    optimize=True,
)
print("best equation:", seq_opt.equation_)
>>>
4.5 mlbest='bhts,bnts->btnh': 100%|##########| 121/121 [00:02<00:00, 45.81it/s]
best equation: bhts,bnts->btnh
predict_transposition_cost¶
- onnx_extended.tools.einsum.einsum_ml.predict_transposition_cost(shape: Tuple[int, ...], perm: Tuple[int, ...], coefs: Dict[str, float] | None = None) float [source]¶
Given a shape and a permutation, predicts the cost of the transposition.
- Parameters:
shape – shape
perm – permutation
coefs – trained coefficients or None to get the default ones
- Returns:
predicted cost of the transposition
<<<
import pprint
from onnx_extended.tools.einsum.einsum_ml import predict_transposition_cost

pprint.pprint(predict_transposition_cost((3, 5, 7), (2, 1, 0)))