Source code for experimental_experiment.xoptim.patterns.onnx_functions

import inspect
from typing import List, Optional
from onnx import NodeProto, TensorProto
from ..patterns_api import EasyPatternOptimization


class GeluPattern(EasyPatternOptimization):
    """
    Detects the decomposed version of Gelu with Tanh.

    .. math::

        y = \\frac{x}{2} \\left(1 + \\tanh\\left(\\sqrt{\\frac{2}{\\pi}}
        \\left(x + 0.044715 x^3\\right)\\right)\\right)
    """

    def __init__(
        self, verbose: int = 0, priority: int = 0, min_opset: int = 20, domain: str = ""
    ):
        super().__init__(verbose, priority, min_opset=min_opset)
        self.domain = domain
    def match_pattern(self, g: "GraphBuilder", x, c3, c04, cpi, one, c2):  # noqa: F821
        x3 = g.op.Pow(x, c3)  # c3 = 3
        cx3 = g.op.Mul(x3, c04)  # c04 = 0.044715
        add = g.op.Add(x, cx3)
        addm = g.op.Mul(add, cpi)  # cpi = 0.7978515625 ~ sqrt(2/pi)
        tanh = g.op.Tanh(addm)
        tanh1 = g.op.Add(tanh, one)  # one = 1
        x2 = g.op.Mul(x, c2)  # c2 = 0.5
        return g.op.Mul(x2, tanh1)
    def apply_pattern(
        self,
        g: "GraphBuilder",  # noqa: F821
        x,
        c3,
        c04,
        cpi,
        one,
        c2,
    ):
        return g.op.Gelu(x, approximate="tanh", domain=self.domain)
    def validate_mapping(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        deleted_nodes: List[NodeProto],
        pattern_nodes: Optional[List[NodeProto]] = None,
    ) -> bool:
        assert len(deleted_nodes) == 8, f"Unexpected pattern length {len(deleted_nodes)}"
        assert deleted_nodes[0].op_type == "Pow", f"-- {deleted_nodes[0]}"
        c3 = deleted_nodes[0].input[1]
        assert deleted_nodes[1].op_type == "Mul", f"-- {deleted_nodes[1]}"
        cx3 = deleted_nodes[1].input[1]
        assert deleted_nodes[3].op_type == "Mul", f"-- {deleted_nodes[3]}"
        cpi = deleted_nodes[3].input[1]
        assert deleted_nodes[5].op_type == "Add", f"-- {deleted_nodes[5]}"
        one = deleted_nodes[5].input[1]
        assert deleted_nodes[6].op_type == "Mul", f"-- {deleted_nodes[6]}"
        c2 = deleted_nodes[6].input[1]

        node = deleted_nodes[0]
        if not g.is_constant_scalar(c3) or g.get_constant_scalar(c3) != 3:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(cx3) or g.get_constant_scalar(cx3) not in (
            0.044715,
            0.044708251953125,
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(cpi) or g.get_constant_scalar(cpi) != 0.7978515625:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(one) or g.get_constant_scalar(one) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(c2) or g.get_constant_scalar(c2) != 0.5:
            return self.none(node, inspect.currentframe().f_lineno)
        return True
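

# Illustrative sketch, not part of the original module: a quick numpy check
# that the decomposed expression matched by GeluPattern agrees with the tanh
# approximation of Gelu it is rewritten to. The helper name and the numpy
# dependency are assumptions made for the example.
def _check_gelu_tanh_decomposition():
    import numpy as np

    x = np.linspace(-4, 4, 9, dtype=np.float32)
    # 0.5 * x * (1 + tanh(0.7978515625 * (x + 0.044715 * x**3))), as matched above
    decomposed = 0.5 * x * (1 + np.tanh(0.7978515625 * (x + 0.044715 * x**3)))
    # Gelu(approximate="tanh") written out with the exact constant sqrt(2/pi)
    reference = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
    assert np.allclose(decomposed, reference, atol=1e-4)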


class LeakyReluPattern(EasyPatternOptimization):
    """
    Detects the decomposed version of LeakyRelu.
    """

    def __init__(self, verbose: int = 0, priority: int = 0, min_opset: int = 6):
        super().__init__(verbose, priority, min_opset=min_opset)
    def match_pattern(self, g: "GraphBuilder", x, zero, slope):  # noqa: F821
        return g.op.Where(g.op.Greater(x, zero), x, g.op.Mul(x, slope))
    def apply_pattern(
        self,
        g: "GraphBuilder",  # noqa: F821
        x,
        zero,
        slope,
    ):
        # g is not the GraphBuilder for the main graph.
        return g.op.LeakyRelu(x, alpha=self.get_validate_param("slope"))
    def validate_mapping(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        deleted_nodes: List[NodeProto],
        pattern_nodes: Optional[List[NodeProto]] = None,
    ) -> bool:
        assert len(deleted_nodes) == 3, f"Unexpected pattern length {len(deleted_nodes)}"
        assert deleted_nodes[2].op_type == "Where", f"-- {deleted_nodes[2]}"
        greater, mul = (
            (deleted_nodes[0], deleted_nodes[1])
            if deleted_nodes[0].op_type == "Greater"
            else (deleted_nodes[1], deleted_nodes[0])
        )
        zero = greater.input[1]
        slope = mul.input[1]
        if not g.is_constant_scalar(zero) or g.get_constant_scalar(zero) != 0:
            return self.none(greater, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(slope):
            return self.none(mul, inspect.currentframe().f_lineno)
        self.add_validate_param("slope", g.get_constant_scalar(slope))
        return True
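

# Illustrative sketch, not part of the original module: the Where/Greater/Mul
# graph matched by LeakyReluPattern computes exactly LeakyRelu(x, alpha=slope).
# numpy stands in for the ONNX runtime; the helper name is an assumption.
def _check_leaky_relu_decomposition(slope: float = 0.01):
    import numpy as np

    x = np.linspace(-3, 3, 7, dtype=np.float32)
    # Where(Greater(x, 0), x, Mul(x, slope))
    decomposed = np.where(x > 0, x, x * slope)
    # LeakyRelu: x for x >= 0, slope * x otherwise
    reference = np.maximum(x, 0) + slope * np.minimum(x, 0)
    assert np.allclose(decomposed, reference)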


class SoftmaxCrossEntropyLossCastPattern(EasyPatternOptimization):
    """
    Detects one decomposed version of SoftmaxCrossEntropyLoss.
    """

    def __init__(
        self, verbose: int = 0, priority: int = 0, min_opset: int = 14, domain: str = ""
    ):
        super().__init__(verbose, priority, min_opset=min_opset)
        self.domain = domain
    def match_pattern(
        self,
        g: "GraphBuilder",  # noqa: F821
        X,
        indices,
        axis,
        zerof,
        zeroi,
        b,
    ):
        # b holds the ignored label (-100, checked in validate_mapping).
        neq1 = g.op.Not(g.op.Equal(indices, b))  # mask of labels to keep
        wh1 = g.op.Where(neq1, indices, zeroi)  # replace ignored labels with 0
        uns = g.op.Unsqueeze(wh1, axis)
        ge = g.op.GatherElements(g.op.LogSoftmax(X, axis=1), uns, axis=1)
        wh2 = g.op.Where(neq1, g.op.Neg(g.op.Squeeze(ge, axis)), zerof)
        # mean reduction: sum of kept losses divided by the number of kept labels
        denominator = g.op.Cast(
            g.op.ReduceSum(
                g.op.Cast(neq1, to=TensorProto.FLOAT),
                keepdims=0,
            ),
            to=TensorProto.FLOAT16,
        )
        numerator = g.op.Cast(
            g.op.ReduceSum(
                g.op.Cast(wh2, to=TensorProto.FLOAT),
                keepdims=0,
            ),
            to=TensorProto.FLOAT16,
        )
        return g.op.Div(numerator, denominator)
    @classmethod
    def apply_pattern(
        cls,
        g: "GraphBuilder",  # noqa: F821
        X,
        indices,
        axis,
        zerof,
        zeroi,
        b,
    ):
        return g.op.SoftmaxCrossEntropyLoss(X, indices, ignore_index=-100, reduction="mean")
    def validate_mapping(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        deleted_nodes: List[NodeProto],
        pattern_nodes: Optional[List[NodeProto]] = None,
    ) -> bool:
        assert len(deleted_nodes) == 16, f"Unexpected pattern length {len(deleted_nodes)}"
        node = deleted_nodes[-1]
        for n in deleted_nodes:
            if n.op_type in {"Squeeze", "Unsqueeze"}:
                c = n.input[1]
                if not g.is_constant_scalar(c) or g.get_constant_scalar(c) != 1:
                    return self.none(node, inspect.currentframe().f_lineno)
                continue
            if n.op_type == "Equal":
                c = n.input[1]
                if not g.is_constant_scalar(c) or g.get_constant_scalar(c) != -100:
                    return self.none(node, inspect.currentframe().f_lineno)
                continue
        return True
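

# Illustrative sketch, not part of the original module: a numpy rendering of
# the decomposition matched by SoftmaxCrossEntropyLossCastPattern, checked
# against the direct definition of SoftmaxCrossEntropyLoss with
# ignore_index=-100 and reduction="mean". The helper name, the sample data and
# the numpy dependency are assumptions made for the example.
def _check_softmax_cross_entropy_decomposition():
    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(4, 5)).astype(np.float32)  # logits, shape (N, C)
    indices = np.array([1, -100, 3, 0])  # -100 marks rows to ignore

    log_softmax = X - np.log(np.exp(X).sum(axis=1, keepdims=True))
    mask = indices != -100  # Not(Equal(indices, b))
    safe = np.where(mask, indices, 0)  # Where(mask, indices, zeroi)
    # GatherElements on LogSoftmax, then Squeeze and Neg
    picked = -log_softmax[np.arange(len(indices)), safe]
    loss_terms = np.where(mask, picked, 0.0)  # Where(mask, ..., zerof)
    decomposed = loss_terms.sum() / mask.sum()  # Div(numerator, denominator)

    # Mean cross-entropy over the rows whose label is not the ignored index.
    reference = np.mean(
        [-log_softmax[i, t] for i, t in enumerate(indices) if t != -100]
    )
    assert np.allclose(decomposed, reference, atol=1e-6)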