Source code for onnx_extended.tools.einsum.einsum_bench

from itertools import permutations
from typing import Any, Dict, Iterable, List, Optional, Union
import numpy
from onnx import helper, ModelProto, TensorProto
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from ...ext_test_case import measure_time
from .einsum_config import DEFAULT_OPSET, DEFAULT_IR_VERSION
from .einsum_impl import decompose_einsum_equation, apply_einsum_sequence


def _measure_time(
    stmt: Any,
    *x: numpy.ndarray,
    repeat: int = 5,
    number: int = 5,
    div_by_number: bool = True,
    first_run: bool = True,
    max_time: Optional[float] = None,
) -> Dict[str, Union[str, float]]:
    """
    Measures a callable applied to some inputs and returns the results
    as a dictionary.

    :param stmt: callable, executed as ``stmt(*x)``
    :param x: positional inputs given to *stmt*
    :param repeat: average over *repeat* experiment
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param first_run: if True, runs the function once before measuring
    :param max_time: execute the statement until the total goes
        beyond this time (approximatively), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary produced by ``measure_time``
    :raises RuntimeError: if the warm-up run fails, the message lists the
        type and dtype of every input and the original error is chained

    See `Timer.repeat
    <https://docs.python.org/3/library/timeit.html?timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    All timing parameters are forwarded to ``measure_time``.
    """
    if first_run:
        # Single warm-up call, also used to report which inputs made the
        # statement fail (x is a tuple, so describe each element).
        try:
            stmt(*x)
        except RuntimeError as e:
            desc = ", ".join(f"{type(v)}-{getattr(v, 'dtype', '?')}" for v in x)
            raise RuntimeError(
                f"Unable to execute the statement with inputs [{desc}]."
            ) from e

    def fct():
        stmt(*x)

    return measure_time(
        fct,
        context={},
        repeat=repeat,
        number=number,
        div_by_number=div_by_number,
        max_time=max_time,
    )


def _make_einsum_model(equation: str, opset: int = DEFAULT_OPSET) -> ModelProto:
    """
    Builds a one-node ONNX model evaluating *equation* with operator Einsum.

    The model has one input per operand found on the left side of ``->``,
    named ``X0``, ``X1``, ..., and a single output ``Y``. All of them are
    declared as float tensors with no shape information.

    :param equation: einsum equation stored as the node attribute
    :param opset: opset version imported for the default domain
    :return: the ONNX model
    """
    # Number of operands == number of commas on the left side + 1.
    operand_count = equation.split("->")[0].count(",") + 1
    input_names = [f"X{i}" for i in range(operand_count)]

    einsum_node = helper.make_node("Einsum", input_names, ["Y"], equation=equation)
    input_infos = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, None)
        for name in input_names
    ]
    output_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)

    graph = helper.make_graph(
        name="einsum_test",
        inputs=input_infos,
        outputs=[output_info],
        nodes=[einsum_node],
    )
    return helper.make_model(
        opset_imports=[helper.make_operatorsetid("", opset)],
        ir_version=DEFAULT_IR_VERSION,
        producer_name="onnx_extended",
        producer_version="0.1",
        graph=graph,
    )


def _make_inputs(equation, shapes):
    """
    Creates random float32 inputs for an einsum equation.

    :param equation: einsum equation, only the operands on the left of
        ``->`` are used, to count the inputs and get their ranks
    :param shapes: one of

        * an integer ``N``: an operand with *r* letters gets shape
          ``(N,) * r``,
        * a sequence with one shape per operand,
        * a string with one shape per operand separated by ``;``, each
          shape being a comma-separated list of integers optionally
          enclosed in parentheses, e.g. ``"(2,3);(3,4)"`` or ``"2,3;3,4"``

    :return: list of float32 arrays drawn with ``numpy.random.randn``
    """
    operands = equation.split("->")[0].split(",")

    if isinstance(shapes, str):
        # "(2,3);(3,4)" -> [(2, 3), (3, 4)]; empty pieces such as the
        # trailing one in "(5,)" are ignored.
        shapes = [
            tuple(int(d) for d in item.strip().strip("()").split(",") if d.strip())
            for item in shapes.split(";")
        ]

    if isinstance(shapes, (int, numpy.integer)):
        shapes = [(shapes,) * len(op) for op in operands]
    else:
        assert len(shapes) == len(
            operands
        ), f"Unexpected number of shapes {shapes!r} with equation {equation!r}."
    return [numpy.random.randn(*sh).astype(numpy.float32) for sh in shapes]


def _einsum_scenarios(equation, shape, perm, runtime):
    """Lists the ``(equation, runtime, method, shape)`` tuples to benchmark."""
    # A list of integers means one scenario per size, anything else is a
    # single shape specification.
    if isinstance(shape, list) and all(isinstance(t, int) for t in shape):
        shape_list = shape
    else:
        shape_list = [shape]

    if perm:
        assert equation.lower() == equation, (
            "Only equations with lower letters are allowed but equation %r "
            "is not." % equation
        )
        # Inclusive bounds: 'z' is a valid einsum letter.
        letters = sorted(set(c for c in equation if "a" <= c <= "z"))
        equations = []
        for p in permutations(letters):
            eq = equation
            # Rename through upper case letters so that a replacement is
            # never replaced again by a later one.
            for src, dst in zip(p, letters):
                eq = eq.replace(src, dst.upper())
            equations.append(eq.lower())
    else:
        equations = [equation]

    return [
        (eq, runtime, dec, sh)
        for eq in equations
        for dec in ["einsum", "dec"]
        for sh in shape_list
    ]


def _make_einsum_callable(eq, rt, dec, n_inputs, opset):
    """Returns a function computing *eq* with runtime *rt* and method *dec*."""
    if dec == "dec":
        seq = decompose_einsum_equation(eq, strategy="numpy", clean=True)
    else:
        seq = None
    names = ["X%d" % i for i in range(n_inputs)]

    if rt == "numpy":
        if dec == "einsum":
            return lambda *x: numpy.einsum(eq, *x, optimize=True)
        return lambda *x: apply_einsum_sequence(seq, *x)

    if rt not in ("onnxruntime", "python"):
        raise ValueError(f"Unexpected runtime {rt!r}.")

    if dec == "einsum":
        # Must use the (possibly permuted) equation being benchmarked.
        onx = _make_einsum_model(eq, opset=opset)
    else:
        assert seq is not None, "seq must not be None."
        onx = seq.to_onnx("Y", *names, opset=opset)

    if rt == "onnxruntime":
        sess = InferenceSession(
            onx.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        return lambda *x: sess.run(None, dict(zip(names, x)))

    oinf = ReferenceEvaluator(onx)
    return lambda *x: oinf.run(None, dict(zip(names, x)))


def einsum_benchmark(
    equation: str = "abc,cd->abd",
    shape: Union[int, List[Any], str] = 30,
    perm: bool = False,
    runtime: str = "python",
    use_tqdm: bool = False,
    number: int = 5,
    repeat: int = 5,
    opset: int = DEFAULT_OPSET,
) -> Iterable[Dict[str, Union[str, float]]]:
    """
    Investigates whether or not decomposing an einsum equation is faster.

    Every equation is benchmarked twice: once as a single einsum
    (``dec="einsum"``) and once as the decomposed sequence returned by
    ``decompose_einsum_equation`` (``dec="dec"``).

    :param equation: einsum equation to test
    :param shape: an integer (all dimensions get the same size), a list of
        integers (one scenario per size), or any other shape specification
        given as is to ``_make_inputs`` (single scenario)
    :param perm: if True, benchmarks every equation obtained by permuting
        the letters of *equation* (only lower case equations are allowed)
    :param runtime: a string among 'numpy' (``numpy.einsum`` or
        ``apply_einsum_sequence``), 'python' (onnx ``ReferenceEvaluator``),
        'onnxruntime' (``InferenceSession`` on CPU)
    :param use_tqdm: show progress
    :param number: usual parameter to measure a function
    :param repeat: usual parameter to measure a function
    :param opset: target opset
    :return: iterator over dictionaries, the statistics returned by
        ``_measure_time`` plus keys ``rt``, ``dec``, ``eq``, ``shapes``
    :raises ValueError: while iterating, if *runtime* is not supported
    """
    scenarios = _einsum_scenarios(equation, shape, perm, runtime)

    if use_tqdm:
        from tqdm import tqdm

        loop = tqdm(scenarios)
    else:
        loop = scenarios

    for eq, rt, dec, sh in loop:
        # Only the number of operands and their ranks matter here, both are
        # preserved by a letter permutation.
        inputs = _make_inputs(eq, sh)
        fct = _make_einsum_callable(eq, rt, dec, len(inputs), opset)
        res = _measure_time(fct, *inputs, repeat=repeat, number=number)
        res["rt"] = rt
        res["dec"] = dec
        res["eq"] = eq
        res["shapes"] = ";".join(str(m.shape) for m in inputs).replace(" ", "")
        yield res