# Source code for experimental_experiment.xoptim.patterns

import inspect
from typing import List, Optional
from onnx import NodeProto
from ..patterns_api import PatternOptimization, MatchResult

# onnx patterns
from .onnx_any import IdentityPattern, SameChildrenPattern
from .onnx_cast import (
    CastPattern,
    CastCastBinaryPattern,
    CastOpCastPattern,
    ComputationCastOpCastPattern,
)
from .onnx_conv import ConvBiasNullPattern
from .onnx_dropout import DropoutPattern
from .onnx_equal import UnsqueezeEqualPattern
from .onnx_expand import ExpandPattern, ExpandBroadcastPattern, ExpandSwapPattern
from .onnx_functions import GeluPattern, LeakyReluPattern, SoftmaxCrossEntropyLossCastPattern
from .onnx_layer_normalization import (
    BatchNormalizationPattern,
    BatchNormalizationTrainingPattern,
    CastLayerNormalizationCastPattern,
    LayerNormalizationPattern,
    LayerNormalizationScalePattern,
)
from .onnx_mul import (
    MulMulMulScalarPattern,
    SwitchOrderBinaryPattern,
)
from .onnx_matmul import (
    GemmTransposePattern,
    MatMulAddPattern,
    MatMulReshape2Of3Pattern,
    MulMulMatMulPattern,
    ReshapeMatMulReshapePattern,
    TransposeMatMulPattern,
    TransposeReshapeMatMulPattern,
)
from .onnx_reduce import ReduceSumNormalizePattern
from .onnx_reshape import (
    ReshapePattern,
    ReduceReshapePattern,
    Reshape2Of3Pattern,
    ReshapeReshapeBinaryPattern,
    ReshapeReshapePattern,
)
from .onnx_rotary import RotaryConcatPartPattern
from .onnx_split import SlicesSplitPattern
from .onnx_sub import Sub1MulPattern
from .onnx_transpose import TransposeTransposePattern, TransposeReshapeTransposePattern
from .onnx_unsqueeze import UnsqueezeUnsqueezePattern


class AlmostDoNothingPattern(PatternOptimization):
    """
    Rewrites a ``Pow`` node (default domain) into an equivalent node whose
    name is prefixed with ``AlmostDoNothing--``.

    The rewrite keeps the operator type, inputs and outputs, so the graph
    computes the same thing; only the node name changes.

    Note that the ``n_count >= 0`` guard at the top of :meth:`match` rejects
    every node, so as written this pattern never matches.
    """

    # Number of nodes matched so far; incremented in ``match``.
    n_count = 0

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches a default-domain ``Pow`` node not already produced by
        :meth:`apply`.

        :param g: graph being optimized (not read here)
        :param node: candidate node
        :param matched: results matched so far (not read here)
        :return: a :class:`MatchResult` covering *node*, or ``self.none(...)``
            when the node is rejected
        """
        # NOTE(review): n_count starts at 0, so this guard is always true and
        # the rest of the method is unreachable; the pattern is effectively
        # disabled. Presumably a kill switch (``>= 1`` would allow a single
        # match) -- confirm intent before changing the threshold.
        if self.n_count >= 0:
            return self.none()
        if node.op_type != "Pow" or node.domain != "":
            return self.none()
        # Nodes created by ``apply`` carry this marker in their name; skipping
        # them stops the pattern from matching its own output again.
        if node.name is not None and "AlmostDoNothing" in node.name:
            return self.none(node, inspect.currentframe().f_lineno)
        self.n_count += 1
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Replaces *node* with a node of the same type, inputs and outputs,
        renamed ``AlmostDoNothing--<original name>``.

        Attributes are not carried over; ``match`` only accepts ``Pow``,
        which defines none.

        :param g: graph builder used to create the replacement node
        :param node: the matched node
        :return: a one-element list holding the replacement node
        """
        return [
            g.make_node(
                node.op_type,
                node.input,
                node.output,
                name=f"AlmostDoNothing--{node.name}",
            )
        ]
def get_default_patterns(verbose: int = 0) -> List[PatternOptimization]:
    """
    Builds the default list of optimization patterns.

    Every pattern is instantiated with the same *verbose* level, and the
    returned list follows the order of the tuple below. The result is:

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.xoptim.patterns import get_default_patterns

        pprint.pprint(get_default_patterns())

    :param verbose: verbosity forwarded to every pattern constructor
    :return: a new list of pattern instances
    """
    # AlmostDoNothingPattern is not enabled by default.
    pattern_classes = (
        BatchNormalizationPattern,
        BatchNormalizationTrainingPattern,
        CastLayerNormalizationCastPattern,
        CastPattern,
        CastCastBinaryPattern,
        CastOpCastPattern,
        ComputationCastOpCastPattern,
        ConvBiasNullPattern,
        DropoutPattern,
        ExpandPattern,
        ExpandBroadcastPattern,
        ExpandSwapPattern,
        GeluPattern,
        IdentityPattern,
        LayerNormalizationPattern,
        LayerNormalizationScalePattern,
        LeakyReluPattern,
        MulMulMulScalarPattern,
        ReduceReshapePattern,
        ReduceSumNormalizePattern,
        ReshapePattern,
        ReshapeMatMulReshapePattern,
        Reshape2Of3Pattern,
        ReshapeReshapeBinaryPattern,
        MatMulAddPattern,
        GemmTransposePattern,
        MatMulReshape2Of3Pattern,
        MulMulMatMulPattern,
        ReshapeReshapePattern,
        RotaryConcatPartPattern,
        SameChildrenPattern,
        SlicesSplitPattern,
        SoftmaxCrossEntropyLossCastPattern,
        Sub1MulPattern,
        SwitchOrderBinaryPattern,
        TransposeMatMulPattern,
        TransposeReshapeMatMulPattern,
        TransposeReshapeTransposePattern,
        TransposeTransposePattern,
        UnsqueezeEqualPattern,
        UnsqueezeUnsqueezePattern,
    )
    return [pattern_cls(verbose=verbose) for pattern_cls in pattern_classes]