# Source code for onnx_extended.tools.einsum.einsum_impl

from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy
from .einsum_impl_classes import EinsumSubOp, GraphEinsumSubOp


def analyse_einsum_equation(
    equation: str,
) -> Tuple[str, numpy.ndarray, List[int], List[Optional[Dict[str, List[int]]]]]:
    """
    Analyses an einsum equation.

    :param equation: :func:`numpy.einsum` equation, with an explicit
        output (``"inputs->output"``); whitespace around every term,
        the output included, is ignored
    :return: four results, list of letters (a sorted string),
        a matrix (see below), lengths of each components
        (inputs first, output last), duplicates (one entry per
        component, ``None`` when no letter is repeated, otherwise a
        dictionary mapping every letter to its positions)
    :raises NotImplementedError: if the equation does not have exactly
        two non empty sides
    :raises ValueError: if an input contains a character which is not
        a lower or upper case letter
    :raises AssertionError: if the output uses a letter missing
        from every input

    The returned matrix has one row per input followed by one row for
    the output and one column per letter. It is defined as follows:

    .. math::

        m_{ij}=\\left\\{\\begin{array}{ll}-1 &
        \\text{if letter j is not involved in input i} \\\\
        p & \\text{p is position of letter j in equation i}
        \\end{array}\\right.

    When a letter is repeated in a term, the last position is stored.
    """
    spl = equation.strip(" ,").split("->")
    if len(spl) != 2 or not spl[0]:
        raise NotImplementedError(
            "The function only implements the case when there are "
            "two sides in the equation: %r." % equation
        )
    inputs = [term.strip() for term in spl[0].split(",")]
    # The output is stripped like the inputs so that "ab,bc -> ac"
    # is handled the same way as "ab,bc->ac".
    output = spl[1].strip()
    if not output:
        raise NotImplementedError(
            "The function only implements the case when there are "
            "two sides in the equation: %r." % equation
        )

    letters = sorted(set().union(*inputs))
    for c in letters:
        if not (("a" <= c <= "z") or ("A" <= c <= "Z")):
            raise ValueError(
                "Equation %r must only contain lower or upper letters "
                "but %r is not." % (equation, c)
            )

    rev = {c: i for i, c in enumerate(letters)}
    for c in output:
        assert (
            c in rev
        ), f"Output contains one unexpected letter {c!r} in equation {equation!r}, letters={letters!r}."

    terms = inputs + [output]
    mat = numpy.full((len(terms), len(letters)), -1, dtype=numpy.int8)
    for i, term in enumerate(terms):
        for k, c in enumerate(term):
            mat[i, rev[c]] = k
    lengths = [len(term) for term in terms]

    # Look for duplicates
    duplicates: List[Optional[Dict[str, List[int]]]] = []
    for term in terms:
        if len(term) == len(set(term)):
            duplicates.append(None)
            continue
        # At least one letter is repeated: every letter of the term is
        # mapped to the list of its positions.
        counts: Dict[str, List[int]] = {}
        for k, c in enumerate(term):
            counts.setdefault(c, []).append(k)
        duplicates.append(counts)

    return "".join(letters), mat, lengths, duplicates


def decompose_einsum_equation(
    equation: str,
    *shapes: Tuple[int, ...],
    strategy: str = "simple",
    clean: bool = False,
    verbose: bool = False,
) -> GraphEinsumSubOp:
    """
    Decomposes an equation used in :func:`numpy.einsum` knowing
    the input shapes. It returns a sequence of operations
    to do to compute the results.

    :param equation: a string
    :param shapes: sequence of input shapes, every shape must be a tuple
    :param strategy: there are different way to decompose the equation,
        this parameters defines the way to do it (see below)
    :param clean: clean the unnecessary node in the graph
    :param verbose: verbosity
    :return: instance of :class:`GraphEinsumSubOp
        <onnx_extended.tools.einsum.einsum_impl_classes.GraphEinsumSubOp>`
    :raises TypeError: if one of the shapes is not a tuple
    :raises ValueError: if *strategy* is unknown

    About *strategy*:

    * `'simple'`: align all dimensions in the alphabetical order,
      some generic matrix multiplication remains implemented with
      :func:`numpy.einsum` but only with two matrices aligned on
      the same dimension (see :func:`numpy_extended_dot
      <onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot>`)
    * `'numpy'`: same as `simple` but the decomposition does not use
      :func:`numpy.einsum` anymore but only multiplication or
      matrix multiplication merged into a single operator called
      *batch_dot* (see :func:`numpy_extended_dot_matrix
      <onnx_extended.tools.einsum.einsum_impl_ext.numpy_extended_dot_matrix>`)

    Available operations: *expand_dims*, *transpose*, *matmul*, *reduce_sum*,
    *id*, *squeeze*, *diagonal*. It analyses an equation and produces a graph
    where nodes are instance of class :class:`EinsumSubOp
    <onnx_extended.tools.einsum.einsum_impl_classes.EinsumSubOp>`.

    .. runpython::
        :showcode:

        from onnx_extended.tools.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation("bac,cd,def->ebc")
        for op in seq:
            print(op)

    It can be better displayed as the following.

    .. gdot::
        :script: DOT-SECTION
        :process:

        from onnx_extended.tools.einsum import decompose_einsum_equation
        seq = decompose_einsum_equation(
            "bac,cd,def->ebc", (2, 2, 2), (2, 2), (2, 2, 2))
        print("DOT-SECTION", seq.to_dot())
    """
    # Iterating over an empty *shapes* is a no-op, no guard is needed.
    for sh in shapes:
        if not isinstance(sh, tuple):
            raise TypeError(f"All shapes must be tuples for {sh!r} is not.")

    if strategy not in ("simple", "numpy"):
        raise ValueError(f"Unknown strategy {strategy!r}.")

    # Both strategies share the same decomposition, they only differ
    # by the operator used for the matrix multiplication.
    op_matmul = {"simple": "matmul", "numpy": "batch_dot"}
    graph = _decompose_einsum_equation_simple(
        equation, *shapes, verbose=verbose, op_matmul=op_matmul[strategy]
    )

    if not clean:
        graph.mark_last_node()
        return graph

    # Last step: clean unused nodes. An identity node is appended after
    # the last added operator before the graph is simplified.
    last_node = graph.last_added_op
    assert isinstance(last_node, EinsumSubOp)
    graph.append(EinsumSubOp(last_node.full_dim, "id", last_node))
    graph.mark_last_node()
    graph.simplify_mm_nodes(verbose=verbose)
    graph.remove_duplicate_transpose(verbose=verbose)
    graph.clean_unused_nodes(verbose=verbose)
    return graph


def apply_einsum_sequence(
    seq: GraphEinsumSubOp,
    *inputs: numpy.ndarray,
    verbose: bool = False,
    **kwargs: Any,
) -> numpy.ndarray:
    """
    Applies a sequence of operations on a list of inputs.
    The sequence of operations is produced by function
    :func:`decompose_einsum_equation`.

    :param seq: sequence of operations, the object must expose
        a method `apply_sequence`
    :param inputs: inputs
    :param verbose: verbosity
    :param kwargs: additional parameters forwarded as is,
        see `apply_sequence` in :class:`GraphEinsumSubOp
        <onnx_extended.tools.einsum.einsum_impl_classes.GraphEinsumSubOp>`
    :return: output

    .. runpython::
        :showcode:

        import numpy
        from onnx_extended.tools.einsum import (
            decompose_einsum_equation, apply_einsum_sequence)

        m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = numpy.arange(4).reshape((2, 2)) + 100
        m3 = numpy.arange(8).reshape((2, 2, 2)) + 1000

        seq = decompose_einsum_equation("bac,cd,def->ebc")
        res = apply_einsum_sequence(seq, m1, m2, m3)
        print(res)
    """
    return seq.apply_sequence(*inputs, verbose=verbose, **kwargs)


def is_transpose_identity(perm: Tuple[int, ...]) -> bool:
    """
    Tells if the permutation *perm* does nothing (identity).

    :param perm: permutation
    :return: boolean
    """
    return list(perm) == list(range(len(perm)))


def _basic_verification(
    lengths: List[int], shapes: List[Tuple[int, ...]], equation: str
) -> None:
    """
    Checks the shapes are consistent with the equation.

    :param lengths: number of letters of every input followed by
        the number of letters of the output
    :param shapes: input shapes
    :param equation: equation, only used in the error messages
    :raises AssertionError: if the number of shapes is not the number of
        inputs or if a shape does not have the expected number of dimensions
    """
    # *lengths* includes the output, it is not an input.
    n_inputs = len(lengths) - 1
    assert n_inputs == len(
        shapes
    ), "Equation %r has %d inputs but %d shapes are given." % (
        equation,
        n_inputs,
        len(shapes),
    )
    for i, (le, sh) in enumerate(zip(lengths, shapes)):
        assert le == len(
            sh
        ), "Inputs %d has %d dimensions but shapes %r has %d " " in equation %r." % (
            i,
            le,
            sh,
            len(sh),
            equation,
        )


def _apply_transpose_reshape(
    op: Union[int, EinsumSubOp], row: numpy.ndarray
) -> Iterable[EinsumSubOp]:
    """
    Put all dimensions in the same order.

    :param op: integer (for one input) or an operator
    :param row: vector with one value per letter, -1 when the letter
        is not involved in this input, its position otherwise
    :return: iterator on the created operators, an *expand_dims*
        always comes first, a *transpose* follows unless the
        permutation is the identity
    """
    axes = []
    p = 0
    perm = []
    for i, r in enumerate(row):
        if r == -1:
            # missing letter: a new axis is inserted
            axes.append((p, i))
        else:
            p += 1
            perm.append((r, i))
    op = EinsumSubOp(len(row), "expand_dims", op, axes=tuple(axes))
    yield op

    perm.sort()
    p = 0
    new_perm = numpy.arange(len(row))
    for i, r in enumerate(row):
        if r == -1:
            continue
        new_perm[perm[p][1]] = i
        p += 1

    if not is_transpose_identity(new_perm):
        op = EinsumSubOp(len(row), "transpose", op, perm=tuple(new_perm))
        yield op


def _apply_squeeze_transpose(
    op: Union[int, EinsumSubOp], row_last: numpy.ndarray, row_output: numpy.ndarray
) -> Iterable[EinsumSubOp]:
    """
    Puts output dimension in the expected order.

    :param op: last operator
    :param row_last: row of the last result, only its length is used
    :param row_output: vector with one value per letter, -1 when the
        letter is not part of the output, its position otherwise
    :return: iterator on the created operators, a *transpose* unless
        the permutation is the identity, then a *squeeze* if at least
        one letter is not part of the output
    """
    perm = []
    sq = []
    for i, d in enumerate(row_output):
        if d == -1:
            sq.append(i)
        else:
            perm.append((d, i))
    perm.sort()

    new_perm = numpy.arange(len(row_last))
    p = 0
    for i, d in enumerate(row_output):
        if d == -1:
            continue
        new_perm[i] = perm[p][1]
        p += 1

    if not is_transpose_identity(new_perm):
        op = EinsumSubOp(len(row_last), "transpose", op, perm=tuple(new_perm))
        yield op
    if sq:
        op = EinsumSubOp(len(row_last), "squeeze", op, axes=tuple(sq))
        yield op


def _apply_einsum_matmul(
    fd, op1, op2, axes, left, right, ndim, op_matmul, row1, row2, verbose=False
) -> Iterable[EinsumSubOp]:
    """
    Decomposes the generic matrix multiplication into numpy operations
    depending on the operator to use for matrix multiplication
    *op_matmul* (see :func:`decompose_einsum_equation`).

    :param fd: first argument given to every created :class:`EinsumSubOp`
    :param op1: left operator
    :param op2: right operator
    :param axes: summation axes
    :param left: axes kept from the left operator
    :param right: axes kept from the right operator
    :param ndim: number of dimensions
    :param op_matmul: `'matmul'`, `'batch_dot'` or `'dot'`
    :param row1: row of the left operator, a dimension is considered
        as present when its value is >= 0
    :param row2: row of the right operator, same convention
    :param verbose: verbosity
    :return: iterator on the created operators
    :raises AssertionError: if *op_matmul* is not one of the allowed values
    :raises NotImplementedError: if *op_matmul* is not `'matmul'` and
        *axes* shares axes with *left* or *right*
    """
    allowed = {"matmul", "batch_dot", "dot"}
    assert (
        op_matmul in allowed
    ), f"Unknown operator op_matmul={op_matmul!r} not in {allowed!r}."

    if op_matmul == "matmul":
        if verbose:
            print(f" -- MATMUL -> matmul axes={axes!r} left={left!r} right={right!r}")
        yield EinsumSubOp(
            fd, "matmul", op1, op2, axes=axes, left=left, right=right, ndim=ndim
        )

    elif len(axes) == 0 and not (set(left) & set(right)):
        # Nothing to sum, nothing in common: element wise multiplication.
        if verbose:
            print(f" -- MATMUL -> mul axes={axes!r} left={left!r} right={right!r}")
        yield EinsumSubOp(fd, "mul", op1, op2)

    elif not (set(axes) & set(left)) and not (set(axes) & set(right)):
        # No intersection between axes and right: matrix multiplication
        if verbose:
            print(
                " -- MATMUL -> batch_dot axes=%r left=%r right=%r"
                "" % (axes, left, right)
            )

        all_axes = set(left) | set(right) | set(axes)
        common_axes = list(set(left) & set(right))
        for i in range(ndim):
            if i not in all_axes:
                common_axes.append(i)
        common_axes.sort()

        # ReduceSum*
        has_dim = {i for i, r in enumerate(row1) if r >= 0}
        right_no_left = (set(right) & has_dim) - (set(right) & (set(left) | set(axes)))
        if right_no_left:
            if verbose:
                print(f" -- MATMUL reduce1 has_dim={has_dim!r} axes={right_no_left!r}")
            op1 = EinsumSubOp(
                fd, "reduce_sum_mm", op1, op2, axes=tuple(sorted(right_no_left))
            )
            yield op1

        has_dim = {i for i, r in enumerate(row2) if r >= 0}
        left_no_right = (set(left) & has_dim) - (set(left) & (set(right) | set(axes)))
        if left_no_right:
            if verbose:
                print(f" -- MATMUL reduce2 has_dim={has_dim!r} axes={left_no_right!r}")
            op2 = EinsumSubOp(fd, "reduce_sum", op2, axes=tuple(sorted(left_no_right)))
            yield op2

        # Transpose: common axes first, summation axes last.
        i_axes = [
            (-1 if i in common_axes else (1 if i in axes else 0), i)
            for i in range(ndim)
        ]
        i_axes.sort()
        perm = [_[1] for _ in i_axes]
        perm_left = [i for i in range(len(perm)) if perm[i] in left]
        perm_right = [i for i in range(len(perm)) if perm[i] in right]
        if not is_transpose_identity(perm):
            op1 = EinsumSubOp(fd, "transpose_mm", op1, op2, perm=tuple(perm))
            yield op1
            op2 = EinsumSubOp(fd, "transpose", op2, perm=tuple(perm))
            yield op2

        # Reshape
        all_axes = list(range(0, ndim))
        new_axes = all_axes[-len(axes) :] if len(axes) > 0 else []
        new_common_axes = all_axes[: len(common_axes)]
        not_in_both = []
        for i in range(0, ndim):
            if i not in left and i not in right and i not in common_axes:
                not_in_both.append(i)

        op = EinsumSubOp(
            fd,
            "batch_dot",
            op1,
            op2,
            batch_axes=tuple(new_common_axes),
            keep_axes=None,
            sum_axes=tuple(new_axes),
            left=tuple(perm_left),
            right=tuple(perm_right),
            ndim=ndim,
        )
        yield op

        # Transpose again
        ordered_axes = (
            common_axes
            + [i for i in left if i not in right]
            + [i for i in right if i not in left]
            + not_in_both
        )
        rev_perm = [(a, i) for i, a in enumerate(ordered_axes)]
        rev_perm.sort()
        rev_perm = [p[1] for p in rev_perm]

        if not is_transpose_identity(rev_perm):
            op_unused = EinsumSubOp(fd, "transpose_mm", op1, op, perm=tuple(rev_perm))
            yield op_unused
            op = EinsumSubOp(fd, "transpose", op, perm=tuple(rev_perm))
            yield op
    else:
        raise NotImplementedError(
            "axes and right or left have axes in common, "
            "axes=%r left=%r right=%r ndim=%r." % (axes, left, right, ndim)
        )


def _decompose_einsum_equation_simple(
    equation: str,
    *shapes: Tuple[int, ...],
    verbose: bool = False,
    op_matmul: str = "matmul",
) -> GraphEinsumSubOp:
    """
    Applies strategies `simple` and `numpy` defined by function
    :func:`decompose_einsum_equation`.

    :param equation: equation
    :param shapes: input shapes, if none is given, every input
        dimension is assumed to be 2
    :param verbose: verbosity
    :param op_matmul: which operator to use for matrix multiplication,
        a single operator *matmul*, or *batch_dot* with *transposes*,
        *reduce_sum*, or just *dot*
    :return: the graph of operations
    :raises RuntimeError: if the output requires a dimension missing
        from the last result
    """
    letters, mat, lengths, duplicates = analyse_einsum_equation(equation)
    assert (
        len(letters) == mat.shape[1]
    ), f"Unexpected number of letters {letters!r}, shape={mat.shape!r}."
    if not shapes:
        shapes = [(2,) * le for le in lengths[:-1]]
    _basic_verification(lengths, shapes, equation)

    # last_row, current_row (row = shape)
    rows = numpy.full((2, mat.shape[1]), -1)
    graph = GraphEinsumSubOp(letters, mat, lengths, duplicates)
    fd = mat.shape[1]
    if verbose:
        print(f"EQUATION={equation!r}")
        print(f"LETTERS={letters!r}", f"LENGTHS={lengths!r}")
        print(f"DUPLICATES={duplicates!r}")

    for i, sh in enumerate(shapes):
        if verbose:
            print()
            print("######### ROW %d shape=%r row=%r" % (i, sh, rows[1, :]))
        graph.append(i)

        # Input matrix aligned to the same dimensions.
        op = EinsumSubOp(fd, "id", i)
        op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
        marked = graph.append(op)

        duplicate = duplicates[i]
        if duplicate is not None:
            # Diagonal: only the letters repeated in this input matter.
            diag = []
            for v in duplicate.values():
                if len(v) == 1:
                    continue
                diag.append((v[0], tuple(v)))
            op = EinsumSubOp(fd, "diagonal", op, diag=diag)
            op.compute_output_row(rows[1, :], mat[i, :], verbose=verbose)
            tr_row = rows[1, :]
            marked = graph.append(op)
        else:
            tr_row = mat[i]

        for op in _apply_transpose_reshape(op, tr_row):
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        # Reduction? (a dimension not used later)
        red = []
        for d in range(0, mat.shape[1]):
            if mat[i + 1 :, d].max() == -1 and rows[1, d] != -1 and rows[0, d] == -1:
                red.append(d)
        if red:
            if verbose:
                print(" -- REDUCE1 row=%d axes=%r" % (i, red))
                print(mat)
                print(" -")
                print(rows)
            op = EinsumSubOp(fd, "reduce_sum", graph.last_added_op, axes=tuple(red))
            op.compute_output_row(rows[1, :], verbose=verbose)
            marked = graph.append(op)

        if graph.last_op is not None:
            # Matrix multiplication?
            common_dims = []
            left = []
            right = []
            for d in range(0, mat.shape[1]):
                if rows[:, d].min() >= 0:
                    # The dimension is in both results: it is kept if it
                    # is still used later, summed otherwise.
                    if mat[i + 1 :, d].max() >= 0:
                        left.append(d)
                        right.append(d)
                    else:
                        common_dims.append(d)
                else:
                    if rows[0, d] >= 0:
                        left.append(d)
                    if rows[1, d] >= 0:
                        right.append(d)
            if verbose:
                print(f" -- MATMUL common_dims={common_dims!r}")
                print(rows)
            for iop in _apply_einsum_matmul(
                fd,
                graph.last_op,
                op,
                axes=tuple(common_dims),
                left=tuple(left),
                right=tuple(right),
                ndim=rows.shape[1],
                op_matmul=op_matmul,
                row1=rows[0, :],
                row2=rows[1, :],
                verbose=verbose,
            ):
                op = iop
                op.compute_output_row(rows[0, :], rows[1, :], ab=True, verbose=verbose)
                marked = graph.append(op)

        # End
        graph.mark(i, marked)
        rows[0, :] = rows[1, :]

    # Final output
    if verbose:
        print()
        print(f"######### FIN row={rows[1, :]!r}")

    if mat[len(shapes), :].max() >= 0:
        rows[1, :] = mat[len(shapes), :]
        red = []
        for d in range(0, mat.shape[1]):
            # NOTE(review): this test uses `> 0` whereas the other row
            # checks of this function use `>= 0`; confirm whether a
            # dimension whose row value is 0 should be reduced as well.
            if rows[0, d] > 0 and rows[1, d] == -1:
                red.append(d)
            elif rows[0, d] == -1 and rows[1, d] >= 0:
                raise RuntimeError(
                    "Issue in equation %r, variable %d, last_result is %r, "
                    "output is %r." % (equation, d, rows[0, :], rows[1, :])
                )
        if red:
            if verbose:
                print(f"-- REDUCE2 axes={red!r}")
                print(mat)
            op = EinsumSubOp(fd, "reduce_sum", op, axes=tuple(red))
            graph.append(op)
            op.compute_output_row(rows[1, :], verbose=verbose)

        # Removes empty axes.
        for op in _apply_squeeze_transpose(op, rows[1, :], mat[len(shapes), :]):
            op.compute_output_row(rows[1, :], verbose=verbose)
            graph.append(op)
    return graph