.xoptim.patterns.onnx_functions

class experimental_experiment.xoptim.patterns.onnx_functions.GeluPattern(verbose: int = 0, priority: int = 0, min_opset: int = 20, domain: str = '')

Detects the decomposed version of Gelu that uses the tanh approximation:

y = \frac{x}{2} \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \left(x + 0.044715\, x^3\right)\right)\right)
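To see how the constants in the graph below map onto this formula, the matched nodes can be evaluated one by one in plain Python. This is a minimal sketch, not the library's code: the function names are hypothetical, Python floats are double precision rather than FLOAT16, and the exact constants are assumed to be 0.044715 and sqrt(2/pi) (the graph labels show them rounded as 0.0447 and 0.798).

```python
import math

# Constants of the decomposed graph (the graph labels show them rounded).
C3 = 3.0                        # Pow(., [3.0])
C04 = 0.044715                  # Mul(., 0.0447)
CPI = math.sqrt(2.0 / math.pi)  # Mul(., 0.798), about 0.7979
ONE = 1.0                       # Add(., 1.0)
HALF = 0.5                      # Mul(., 0.5)


def gelu_decomposed(x: float) -> float:
    """Hypothetical helper: evaluates the eight matched nodes in order."""
    pow_1 = x ** C3            # Pow
    mul_2 = pow_1 * C04        # Mul
    add_3 = x + mul_2          # Add
    mul_4 = add_3 * CPI        # Mul
    tanh_5 = math.tanh(mul_4)  # Tanh
    add_6 = tanh_5 + ONE       # Add
    mul_7 = x * HALF           # Mul
    return mul_7 * add_6       # Mul, output of the subgraph


def gelu_tanh(x: float) -> float:
    """Closed-form tanh approximation, written as in the formula."""
    return x / 2 * (1 + math.tanh(CPI * (x + C04 * x ** 3)))


def gelu_exact(x: float) -> float:
    """Exact Gelu, x * Phi(x), for comparison only."""
    return 0.5 * x * (1 + math.erf(x / math.sqrt(2.0)))


xs = [i / 10 for i in range(-50, 51)]
print(max(abs(gelu_decomposed(v) - gelu_tanh(v)) for v in xs))
print(max(abs(gelu_tanh(v) - gelu_exact(v)) for v in xs))
```

The first printed value shows that the node chain and the closed form agree up to floating-point rounding; the second shows that the tanh form stays close to, but is not identical to, the erf-based Gelu.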

Model with nodes to be fused:

digraph {
  graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
  node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
  edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
  I_0 [label="linear_5\nFLOAT16(4,512,16384)", fillcolor="#aaeeaa"];
  Pow_1 [label="Pow(., [3.0])", fillcolor="#cccccc"];
  Mul_2 [label="Mul(., 0.0447)", fillcolor="#cccccc"];
  Add_3 [label="Add(., .)", fillcolor="#cccccc"];
  Mul_4 [label="Mul(., 0.798)", fillcolor="#cccccc"];
  Tanh_5 [label="Tanh(.)", fillcolor="#cccccc"];
  Add_6 [label="Add(., 1.0)", fillcolor="#cccccc"];
  Mul_7 [label="Mul(., 0.5)", fillcolor="#cccccc"];
  Mul_8 [label="Mul(., .)", fillcolor="#cccccc"];
  I_0 -> Pow_1 [label="FLOAT16(4,512,16384)"];
  Pow_1 -> Mul_2 [label="FLOAT16(4,512,16384)"];
  I_0 -> Add_3 [label="FLOAT16(4,512,16384)"];
  Mul_2 -> Add_3 [label="FLOAT16(4,512,16384)"];
  Add_3 -> Mul_4 [label="FLOAT16(4,512,16384)"];
  Mul_4 -> Tanh_5 [label="FLOAT16(4,512,16384)"];
  Tanh_5 -> Add_6 [label="FLOAT16(4,512,16384)"];
  I_0 -> Mul_7 [label="FLOAT16(4,512,16384)"];
  Mul_7 -> Mul_8 [label="FLOAT16(4,512,16384)"];
  Add_6 -> Mul_8 [label="FLOAT16(4,512,16384)"];
  O_9 [label="mul_4\nFLOAT16(4,512,16384)", fillcolor="#aaaaee"];
  Mul_8 -> O_9;
}

Outcome of the fusion:

digraph {
  graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
  node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
  edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
  I_0 [label="linear_5\nFLOAT16(4,512,16384)", fillcolor="#aaeeaa"];
  Gelu_1 [label="Gelu(.)", fillcolor="#cccccc"];
  I_0 -> Gelu_1 [label="FLOAT16(4,512,16384)"];
  O_2 [label="mul_4\nFLOAT16(4,512,16384)", fillcolor="#aaaaee"];
  Gelu_1 -> O_2;
}

apply_pattern(g: GraphBuilder, x, c3, c04, cpi, one, c2)

Applies the replacement.

match_pattern(g: GraphBuilder, x, c3, c04, cpi, one, c2)

Builds the pattern to match.

validate_mapping(g: GraphBuilderPatternOptimization, deleted_nodes: List[NodeProto], pattern_nodes: List[NodeProto] | None = None) → bool

Validates the mapping.

Parameters:
  • g – the graph being optimized (GraphBuilderPatternOptimization)

  • deleted_nodes – matched nodes from the model (to be deleted)

  • pattern_nodes – matched nodes coming from the pattern

Returns:

True if the mapping is accepted and the matched nodes may be replaced, False otherwise; the default implementation returns True

class experimental_experiment.xoptim.patterns.onnx_functions.LeakyReluPattern(verbose: int = 0, priority: int = 0, min_opset: int = 6)

Detects the decomposed version of LeakyRelu, Where(Greater(x, 0), x, Mul(x, slope)), and replaces it with a single LeakyRelu node.
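The equivalence this pattern relies on can be checked element-wise in plain Python. A minimal sketch, assuming the compared constant is 0 and the slope is a scalar (-0.33 in the graph below); the helper names are hypothetical and the reference follows the ONNX LeakyRelu definition (x for x >= 0, alpha * x otherwise).

```python
def leaky_relu_decomposed(x: float, zero: float = 0.0, slope: float = -0.33) -> float:
    """Hypothetical helper: evaluates the three matched nodes for one element."""
    greater_1 = x > zero              # Greater(., [0.0])
    mul_2 = x * slope                 # Mul(., [slope])
    return x if greater_1 else mul_2  # Where(cond, x, mul_2)


def leaky_relu(x: float, alpha: float) -> float:
    """LeakyRelu as defined by ONNX: x for x >= 0, alpha * x otherwise."""
    return x if x >= 0 else alpha * x


values = [-3.0, -0.5, 0.0, 0.25, 2.0]
print([leaky_relu_decomposed(v) for v in values])
print([leaky_relu(v, alpha=-0.33) for v in values])
```

The two definitions only take different branches at x == 0, where both return zero (the decomposed form may produce -0.0), so the fused node can use alpha = slope.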

Model with nodes to be fused:

digraph {
  graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
  node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
  edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
  I_0 [label="X1\nFLOAT(3,3)", fillcolor="#aaeeaa"];
  Greater_1 [label="Greater(., [0.0])", fillcolor="#cccccc"];
  Mul_2 [label="Mul(., [-0.33])", fillcolor="#cccccc"];
  Where_3 [label="Where(., ., .)", fillcolor="#cccccc"];
  I_0 -> Greater_1 [label="FLOAT(3,3)"];
  I_0 -> Mul_2 [label="FLOAT(3,3)"];
  Greater_1 -> Where_3 [label="BOOL(3,3)"];
  I_0 -> Where_3 [label="FLOAT(3,3)"];
  Mul_2 -> Where_3 [label="FLOAT(3,3)"];
  O_4 [label="Y\nFLOAT(3,3)", fillcolor="#aaaaee"];
  Where_3 -> O_4;
}

Outcome of the fusion:

digraph {
  graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
  node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
  edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
  I_0 [label="X1\nFLOAT(3,3)", fillcolor="#aaeeaa"];
  LeakyRelu_1 [label="LeakyRelu(.)", fillcolor="#cccccc"];
  I_0 -> LeakyRelu_1 [label="FLOAT(3,3)"];
  O_2 [label="Y\nFLOAT(3,3)", fillcolor="#aaaaee"];
  LeakyRelu_1 -> O_2;
}

apply_pattern(g: GraphBuilder, x, zero, slope)

Applies the replacement.

match_pattern(g: GraphBuilder, x, zero, slope)

Builds the pattern to match.

validate_mapping(g: GraphBuilderPatternOptimization, deleted_nodes: List[NodeProto], pattern_nodes: List[NodeProto] | None = None) → bool

Validates the mapping.

Parameters:
  • g – the graph being optimized (GraphBuilderPatternOptimization)

  • deleted_nodes – matched nodes from the model (to be deleted)

  • pattern_nodes – matched nodes coming from the pattern

Returns:

True if the mapping is accepted and the matched nodes may be replaced, False otherwise; the default implementation returns True

class experimental_experiment.xoptim.patterns.onnx_functions.SoftmaxCrossEntropyLossCastPattern(verbose: int = 0, priority: int = 0, min_opset: int = 14, domain: str = '')

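Detects one decomposed version of SoftmaxCrossEntropyLoss: the mean loss over the targets that are not ignored (-100 in the example below), where both sums are computed in FLOAT between Cast nodes and cast back to FLOAT16.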
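What the matched subgraph computes can be followed node by node in plain Python and compared with the mean-reduction definition of SoftmaxCrossEntropyLoss without class weights. This is a sketch under stated assumptions, not the library's implementation: the helper names are hypothetical, the ignored index is assumed to be -100 as in the graph below, and the Cast nodes are no-ops because Python floats are already double precision.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # assumed constant, as in Equal(., [-100]) in the graph


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable LogSoftmax over one row (axis=1)."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def decomposed_loss(X: Sequence[Sequence[float]], I: Sequence[int]) -> float:
    """Hypothetical helper following the matched nodes in order."""
    keep = [t != IGNORE_INDEX for t in I]                  # Equal + Not
    idx = [t if k else 0 for t, k in zip(I, keep)]         # Where(., I, [0])
    logp = [log_softmax(row) for row in X]                 # LogSoftmax(axis=1)
    picked = [lp[j] for lp, j in zip(logp, idx)]           # Unsqueeze + GatherElements + Squeeze
    neg = [-v for v in picked]                             # Neg
    masked = [v if k else 0.0 for v, k in zip(neg, keep)]  # Where(., ., [0.0])
    count = sum(1.0 if k else 0.0 for k in keep)           # Cast + ReduceSum
    total = sum(masked)                                    # Cast + ReduceSum
    return total / count                                   # Div


def softmax_cross_entropy_mean(X, I, ignore_index: int = IGNORE_INDEX) -> float:
    """Mean loss over non-ignored targets, no class weights."""
    losses = [-log_softmax(row)[t] for row, t in zip(X, I) if t != ignore_index]
    return sum(losses) / len(losses)


X = [[1.0, 2.0, 0.5], [0.1, -1.0, 3.0], [2.0, 2.0, 2.0]]
I = [1, IGNORE_INDEX, 2]
print(decomposed_loss(X, I), softmax_cross_entropy_mean(X, I))
```

The second row is ignored: its index is replaced by 0 only so that GatherElements stays in range, and its loss is masked to 0.0 before the sum, so it contributes neither to the numerator nor to the count.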

Model with nodes to be fused:

digraph {
  graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
  node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
  edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
  I_0 [label="I\nINT64(A)", fillcolor="#aaeeaa"];
  I_1 [label="X\nFLOAT16(A,B)", fillcolor="#aaeeaa"];
  Equal_2 [label="Equal(., [-100])", fillcolor="#cccccc"];
  Not_3 [label="Not(.)", fillcolor="#cccccc"];
  Where_4 [label="Where(., ., [0])", fillcolor="#cccccc"];
  Unsqueeze_5 [label="Unsqueeze(., [1])", fillcolor="#eeeeee"];
  LogSoftmax_6 [label="LogSoftmax(., axis=1)", fillcolor="#cccccc"];
  GatherElements_7 [label="GatherElements(., ., axis=1)", fillcolor="#cccccc"];
  Squeeze_8 [label="Squeeze(., [1])", fillcolor="#eeeeee"];
  Neg_9 [label="Neg(.)", fillcolor="#cccccc"];
  Where_10 [label="Where(., ., [0.0])", fillcolor="#cccccc"];
  Cast_11 [label="Cast(., to=FLOAT)", fillcolor="#cccccc"];
  ReduceSum_12 [label="ReduceSum(.)", fillcolor="#cccccc"];
  Cast_13 [label="Cast(., to=FLOAT16)", fillcolor="#cccccc"];
  Cast_14 [label="Cast(., to=FLOAT)", fillcolor="#cccccc"];
  ReduceSum_15 [label="ReduceSum(.)", fillcolor="#cccccc"];
  Cast_16 [label="Cast(., to=FLOAT16)", fillcolor="#cccccc"];
  Div_17 [label="Div(., .)", fillcolor="#cccccc"];
  I_0 -> Equal_2 [label="INT64(A)"];
  Equal_2 -> Not_3 [label="BOOL(A)"];
  Not_3 -> Where_4 [label="BOOL(A)"];
  I_0 -> Where_4 [label="INT64(A)"];
  Where_4 -> Unsqueeze_5 [label="INT64(A)"];
  I_1 -> LogSoftmax_6 [label="FLOAT16(A,B)"];
  LogSoftmax_6 -> GatherElements_7 [label="FLOAT16(A,B)"];
  Unsqueeze_5 -> GatherElements_7 [label="INT64(A,1)"];
  GatherElements_7 -> Squeeze_8 [label="FLOAT16(A,1)"];
  Squeeze_8 -> Neg_9 [label="FLOAT16(A)"];
  Not_3 -> Where_10 [label="BOOL(A)"];
  Neg_9 -> Where_10 [label="FLOAT16(A)"];
  Not_3 -> Cast_11 [label="BOOL(A)"];
  Cast_11 -> ReduceSum_12 [label="FLOAT(A)"];
  ReduceSum_12 -> Cast_13 [label="FLOAT()"];
  Where_10 -> Cast_14 [label="FLOAT16(A)"];
  Cast_14 -> ReduceSum_15 [label="FLOAT(A)"];
  ReduceSum_15 -> Cast_16 [label="FLOAT()"];
  Cast_16 -> Div_17 [label="FLOAT16()"];
  Cast_13 -> Div_17 [label="FLOAT16()"];
  O_18 [label="Y\nFLOAT16()", fillcolor="#aaaaee"];
  Div_17 -> O_18;
}

Outcome of the fusion:

digraph {
  graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
  node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
  edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
  I_0 [label="I\nINT64(A)", fillcolor="#aaeeaa"];
  I_1 [label="X\nFLOAT16(A,B)", fillcolor="#aaeeaa"];
  SoftmaxCrossEntropyLoss_2 [label="SoftmaxCrossEntropyLoss(., .)", fillcolor="#cccccc"];
  I_1 -> SoftmaxCrossEntropyLoss_2 [label="FLOAT16(A,B)"];
  I_0 -> SoftmaxCrossEntropyLoss_2 [label="INT64(A)"];
  O_3 [label="Y\nFLOAT16()", fillcolor="#aaaaee"];
  SoftmaxCrossEntropyLoss_2 -> O_3;
}

classmethod apply_pattern(g: GraphBuilder, X, indices, axis, zerof, zeroi, b)

Applies the replacement.

match_pattern(g: GraphBuilder, X, indices, axis, zerof, zeroi, b)

Builds the pattern to match.

validate_mapping(g: GraphBuilderPatternOptimization, deleted_nodes: List[NodeProto], pattern_nodes: List[NodeProto] | None = None) → bool

Validates the mapping.

Parameters:
  • g – the graph being optimized (GraphBuilderPatternOptimization)

  • deleted_nodes – matched nodes from the model (to be deleted)

  • pattern_nodes – matched nodes coming from the pattern

Returns:

True if the mapping is accepted and the matched nodes may be replaced, False otherwise; the default implementation returns True