experimental_experiment.xoptim.patterns

class experimental_experiment.xoptim.patterns.AlmostDoNothingPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Checks that an Expand is really needed.

apply(g: GraphBuilder, node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the match already happened. It receives the list of nodes returned by method match and assumes no other pattern optimizer will modify them. Since match returns a list of arguments, that list can include None values. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph, and every node returned by it is added. If a received node must be kept, it must be part of the returned list.

Parameters:

nodes – nodes returned by method match; they are then removed from the graph

Returns:

nodes to add to graph.
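The contract above can be illustrated with a small standalone sketch. The Node class below is a simplified stand-in for onnx NodeProto, and apply_expand_identity is a hypothetical rewriting function, not the library's actual API; it shows only the contract: every node passed in is considered removed, every node returned is added back.

```python
# Minimal sketch of the apply contract (mocked types, not the real API).
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:  # simplified stand-in for onnx NodeProto
    op_type: str
    inputs: List[str]
    outputs: List[str]


def apply_expand_identity(nodes: List[Optional[Node]]) -> List[Node]:
    """Rewrites a useless Expand into an Identity.

    ``nodes`` mimics the list returned by ``match``; it may contain None.
    """
    (expand_node,) = [n for n in nodes if n is not None]
    # The Expand node is removed; an Identity node replaces it.  If we had
    # wanted to keep the Expand, we would have returned it as well.
    return [Node("Identity", expand_node.inputs[:1], expand_node.outputs)]


expand = Node("Expand", ["X", "shape"], ["Y"])
new_nodes = apply_expand_identity([expand, None])
print([n.op_type for n in new_nodes])
```

Note how the output name "Y" is preserved: the replacement node must produce the same outputs the removed node did, so the rest of the graph stays connected.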

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given one.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it requires.

  • matched – usually unused; it contains the list of nodes already matching a pattern

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:

  • a list of nodes involved in the rewriting. Not all of them will necessarily be removed, but all of them are needed to do the rewriting and must not be impacted by another pattern optimizer.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer will automatically determine the position of the new nodes.
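A minimal standalone sketch of this contract is shown below. The Node, Graph, and MatchResult classes are simplified stand-ins (the real GraphBuilderPatternOptimization and MatchResult have richer APIs), and match_useless_expand is a hypothetical example; it illustrates only the shape of a match function: inspect the graph around node without modifying it, then return None or a result object.

```python
# Minimal sketch of the match contract (mocked types, not the real API).
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Node:  # simplified stand-in for onnx NodeProto
    op_type: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class MatchResult:  # simplified stand-in for the library's MatchResult
    nodes: List[Optional[Node]]       # nodes involved in the rewriting
    apply: Callable                   # function doing the rewriting
    insert_at: Optional[Node] = None  # where the new nodes can be inserted


class Graph:  # simplified stand-in for GraphBuilderPatternOptimization
    def __init__(self, nodes, shapes):
        self.nodes = nodes
        self.shapes = shapes  # name -> shape, when statically known

    def shape_of(self, name):
        return self.shapes.get(name)


def match_useless_expand(g: Graph, node: Node) -> Optional[MatchResult]:
    """Matches an Expand whose output shape equals its input shape."""
    if node.op_type != "Expand":
        return None
    in_shape = g.shape_of(node.inputs[0])
    out_shape = g.shape_of(node.outputs[0])
    if in_shape is None or in_shape != out_shape:
        return None  # no match: the Expand may actually broadcast
    # The graph is untouched; the rewriting itself happens in ``apply``
    # (a placeholder lambda here).
    return MatchResult(nodes=[node], apply=lambda nodes: nodes)


g = Graph(
    nodes=[Node("Expand", ["X", "shape"], ["Y"])],
    shapes={"X": (2, 3), "Y": (2, 3)},
)
res = match_useless_expand(g, g.nodes[0])
print(res is not None)
```

The key point is that match only gathers evidence and packages it into a result; all graph mutation is deferred to the apply callable.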

experimental_experiment.xoptim.patterns.get_default_patterns(verbose: int = 0) → List[PatternOptimization]

Returns a default list of optimization patterns. It is equal to the following list.

<<<

from experimental_experiment.xoptim.patterns_api import pattern_table_doc
from experimental_experiment.xoptim.patterns import get_default_patterns

print(pattern_table_doc(get_default_patterns(), as_rst=True))

>>>

index. name (short_name, priority): doc

0. BatchNormalizationPattern (BatchNormalization, priority 0): Checks that a BatchNormalization is really needed.

1. BatchNormalizationTrainingPattern (BatchNormalizationTraining, priority 0): Checks that a BatchNormalization in training mode can be avoided.

2. CastLayerNormalizationCastPattern (CastLayerNormalizationCast, priority 1): Checks that a Cast is really needed around LayerNormalization.

3. CastPattern (Cast, priority 0): Checks that a Cast is really needed.

4. CastCastBinaryPattern (CastCastBinary, priority 1): Moves two Cast operators past a binary operator. The casts must convert from one float type to another float type.

5. CastOpCastPattern (CastOpCast, priority 1): Removes two Cast nodes surrounding another operator.

6. ClipClipPattern (ClipClip, priority 1): Merges consecutive Clip nodes if one defines min and the other max.

7. ComputationCastOpCastPattern (ComputationCastOpCast, priority 1): Changes the computation type to make it faster if one of the inputs was just cast before.

8. ConcatEmptyPattern (ConcatEmpty, priority 1): Checks if one of the concatenated values is empty.

9. ConcatGatherPattern (ConcatGather, priority 0): Checks if Gather(Concat) can be replaced by Identity.

10. ConcatReshapePattern (ConcatReshape, priority 0): Tries to reduce the number of nodes in the sequence Concat + Reshape by replacing one of the dimensions by -1.

11. ConcatTwiceUnaryPattern (ConcatTwiceUnary, priority 1): Sin(Concat(x, x)) -> Concat(Sin(x), Sin(x)).

12. ConvBiasNullPattern (ConvBiasNull, priority 0): Checks that a Conv has a null bias.

13. DropoutPattern (Dropout, priority 1): Checks that a Dropout is really needed.

14. ExpandPattern (Expand, priority 0): Checks that an Expand is really needed.

15. ExpandBroadcastPattern (ExpandBroadcast, priority 1): Checks that an Expand is really needed before an element-wise operator. The objective is to save one allocation and let the next operator do the expansion by broadcasting one input.

16. ExpandSwapPattern (ExpandSwap, priority 1): Tries to move an Expand node forward in the graph. Expand + Exp can be changed into Exp + Expand. Then Exp applies on a tensor of a smaller or equal size.

17. GeluPattern (Gelu, priority 0): Detects the decomposed version of Gelu with Tanh.

18. IdentityPattern (Identity, priority 0): Replaces operators such as Div(X, 1), Mul(X, 1), Add(X, 0), Sub(X, 0), Transpose(X, [0, 1, 2, ...]) by Identity nodes.

19. LayerNormalizationPattern (LayerNormalization, priority 1): Fuses nodes of a LayerNormalization.

20. LayerNormalizationScalePattern (LayerNormalizationScale, priority 1): Fuses a LayerNormalization with a scale and bias applied just after.

21. LeakyReluPattern (LeakyRelu, priority 0): Detects the decomposed version of LeakyRelu.

22. MulMulMulScalarPattern (MulMulMulScalar, priority 1): Replaces the sequence {Div | Mul} and {Div | Mul} + {Div | Mul} with {Div | Mul} Mul.

23. ReduceReshapePattern (ReduceReshape, priority 1): Replaces the sequence Reduce* + Reshape if the Reshape is only introduced to deal with a dimension kept because keepdims=1.

24. ReduceSumNormalizePattern (ReduceSumNormalize, priority 1): Nodes equivalent to a reduction.

25. ReshapePattern (Reshape, priority 0): Checks that a Reshape is really needed.

26. ReshapeMatMulReshapePattern (ReshapeMatMulReshape, priority 1): Replaces the sequence Reshape, MatMul, Reshape by MatMul.

27. Reshape2Of3Pattern (Reshape2Of3, priority 1): Replaces the reshapes around element-wise operators. It can be 3 or 2 out of 3.

28. ReshapeReshapeBinaryPattern (ReshapeReshapeBinary, priority 1): Moves two Reshape operators past a binary operator if possible.

29. MatMulAddPattern (MatMulAdd, priority 3): Replaces the sequence MatMul, Add by Gemm. By default, no reshape is allowed; this happens only if the inputs have two dimensions.

30. GemmTransposePattern (GemmTranspose, priority 1): Replaces Gemm(., constant) by Gemm(., constant', transB=1).

31. MatMulReshape2Of3Pattern (MatMulReshape2Of3, priority 1): Replaces the reshapes around a MatMul. It can be 3 or 2 out of 3. It is similar to experimental_experiment.xoptim.patterns.onnx_reshape.Reshape2Of3Pattern.

32. MulMulMatMulPattern (MulMulMatMul, priority 1): Replaces MatMul(a*c, b*d), where c and d are constant scalars, by MatMul(a, b) * (c*d).

33. ShapeBasedReshapeIsSqueezePattern (ShapeBasedReshapeIsSqueeze, priority 0): Replaces a Reshape by a Squeeze or Unsqueeze pattern if possible. It is only available for opset < 18.

34. ShapeBasedStaticExpandPattern (ShapeBasedStaticExpand, priority 0): Compares input and output shapes to tell if the Expand can use a constant as a second input.

35. ShapeBasedConcatExpandPattern (ShapeBasedConcatExpand, priority 1): Rewrites Expand(X, Concat(...)) if possible.

36. ShapeBasedEditDistanceReshapePattern (ShapeBasedEditDistanceReshape, priority 0): Tries to reduce the number of nodes in the sequence Concat + Reshape by replacing one of the dimensions by -1 or 0. The pattern tries to align shape information to infer a static shape.

37. ShapeBasedIdentityPattern (ShapeBasedIdentity, priority 0): If a Slice leads to the same shape and the step is 1, then it is an identity. In some cases, just knowing the shape is the same is enough to replace it.

38. ShapeBasedExpandBroadcastPattern (ShapeBasedExpandBroadcast, priority 1): Similar to experimental_experiment.xoptim.patterns.onnx_expand.ExpandBroadcastPattern, but it allows dynamic shapes as well. It does not look into the second argument of Expand; it just infers that an Expand is not needed for a binary operator following just after.

39. ShapeBasedExpandBroadcastMatMulPattern (ShapeBasedExpandBroadcastMatMul, priority 1): Similar to experimental_experiment.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern, but works only with MatMul.

40. ShapeBasedExpandCastWhereSwapPattern (ShapeBasedExpandCastWhereSwap, priority 1): Rewrites Where(Cast(X), X, cond).

41. ShapeBasedExpandSwapPattern (ShapeBasedExpandSwap, priority 1): Tries to move an Expand node forward in the graph for a binary operator. The code is similar to experimental_experiment.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern.

42. ShapeBasedMatMulToMulPattern (ShapeBasedMatMulToMul, priority 1): MatMul can be replaced by Mul with broadcast. It makes it easier to detect optimization patterns with Expand operators.

43. ShapeBasedSameChildrenPattern (ShapeBasedSameChildren, priority 0): Checks there is no duplicated node doing the same thing as a sibling node. experimental_experiment.xoptim.patterns.onnx_any.SameChildrenPattern checks the nodes are exactly the same; this one assumes they are equivalent in some cases, such as Expand(X, sh1) = Expand(X, sh2) if the output shapes are the same.

44. ShapeBasedShapeShapeAddPattern (ShapeBasedShapeShapeAdd, priority 0): Tries to find another way to get a dimension obtained as the addition of two others.

45. ReshapeReshapePattern (ReshapeReshape, priority 0): Replaces the sequence Reshape, Reshape by Reshape.

46. RotaryEmbeddingPattern (RotaryEmbedding, priority 1): Fuses nodes matching RotaryEmbedding(23).

47. SameChildrenPattern (SameChildren, priority 0): Checks there is no duplicated node doing the same thing as a sibling node.

48. SequenceConstructAtPattern (SequenceConstructAt, priority 1): Replaces the sequence SequenceConstruct(x1, x2, ...) followed by SequenceAt(seq, 0), SequenceAt(seq, 1), ...

49. SliceSlicePattern (SliceSlice, priority 1): Merges consecutive Slice nodes if the axes are disjoint.

50. SlicesSplitPattern (SlicesSplit, priority 1): Merges multiple parallel slices into a Split.

51. SoftmaxCrossEntropyLossCastPattern (SoftmaxCrossEntropyLossCast, priority 0): Detects one decomposed version of SoftmaxCrossEntropyLoss.

52. SplitConcatPattern (SplitConcat, priority 1): Replaces Split + Concat by Identity if this is equivalent.

53. SqueezeAddPattern (SqueezeAdd, priority 0): Replaces the sequence Add(Squeeze, Squeeze) by Squeeze(Add).

54. SqueezeUnsqueezePattern (SqueezeUnsqueeze, priority 0): Replaces the sequence Squeeze, Unsqueeze by Identity, or the other way around.

55. StaticConcatReshapePattern (StaticConcatReshape, priority 0): Tries to reduce the number of nodes in the sequence Concat + Reshape by replacing one of the dimensions by -1.

56. Sub1MulPattern (Sub1Mul, priority 1): Replaces the sequence (1 - X) x Y by Y - X x Y to avoid the creation of a constant in the graph; x means element-wise multiplication.

57. SwitchOrderBinaryPattern (SwitchOrderBinary, priority 1): If it makes sense, switches the order of two multiplications or two additions if the broadcasting reduces one operator to an insignificant number of elements.

58. SwitchReshapeActivationPattern (SwitchReshapeActivation, priority 1): Switches Gelu and Reshape after a Gemm or a MatMul. Gelu can also be Exp, Elu, Relu, Tan, Tanh, Cos, Cosh, Sin, Sinh, Erf, LeakyRelu, PRelu, Selu, Softmax, Softplus. Reshape can also be Transpose.

59. TransposeEqualReshapePattern (TransposeEqualReshape, priority 1): Replaces a Transpose by a Reshape when the switched dimensions are all equal to 1 but one.

60. TransposeMatMulPattern (TransposeMatMul, priority 1): Replaces the sequence Transpose, MatMul or Gemm by Gemm.

61. TransposeReshapeMatMulPattern (TransposeReshapeMatMul, priority 1): Replaces the sequence Transpose, Reshape, MatMul by Reshape, Transpose, MatMul if possible. Another optimizer will optimize this sequence by using Gemm or better.

62. TransposeReshapeTransposePattern (TransposeReshapeTranspose, priority 0): Swaps Reshape and Transpose in a sequence.

63. TransposeTransposePattern (TransposeTranspose, priority 0): Removes two consecutive Transpose nodes if the second one puts the tensor back into its original shape.

64. UnsqueezeEqualPattern (UnsqueezeEqual, priority 1): Replaces the sequence R -> Equal -> Unsqueeze, R -> Unsqueeze, by R -> Unsqueeze -> Equal.

65. UnsqueezeUnsqueezePattern (UnsqueezeUnsqueeze, priority 0): Replaces the sequence Unsqueeze, Unsqueeze by Unsqueeze.

66. RotaryConcatPartPattern (RotaryConcatPart, priority 1): Optimizes a rotary concatenation pattern.

67. FunctionCausalMaskPattern (FunctionCausalMask, priority 1): Fuses nodes matching CausalMask into a local function.

68. FunctionCausalMaskMulAddPattern (FunctionCausalMaskMulAdd, priority 1): Fuses nodes matching CausalMask into a local function.

69. FunctionCosSinCachePattern (FunctionCosSinCache, priority 1): Fuses nodes to simplify the creation of cos/sin caches in LLMs.

70. FunctionHalfRotaryEmbeddingPattern (FunctionHalfRotaryEmbedding, priority 1): Fuses nodes matching half RotaryEmbedding(23) into a local function.

71. RMSNormalizationPattern (RMSNormalization, priority 1): Fuses the nodes equivalent to RMSNormalization(23).
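Many of the rewrites above are simple algebraic identities. As a sanity check of the Sub1MulPattern entry, the sketch below verifies with NumPy that (1 - X) x Y equals Y - X x Y element-wise, which is why the rewrite can drop the constant 1 from the graph:

```python
# Numerical check of the Sub1MulPattern identity: (1 - X) * Y == Y - X * Y.
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 5)).astype(np.float32)
Y = rng.standard_normal((4, 5)).astype(np.float32)

before = (1 - X) * Y  # original sequence: Sub(1, X) then Mul
after = Y - X * Y     # rewritten sequence: Mul(X, Y) then Sub

print(np.allclose(before, after, atol=1e-6))
```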