Note
Go to the end to download the full example code.
101: Onnx Model Rewriting¶
This example shows how to rewrite a graph using a pattern.
A model¶
from typing import List, Optional
import onnx.helper as oh
from onnx import NodeProto, TensorProto
from experimental_experiment.helpers import pretty_onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.xbuilder.graph_builder import (
GraphBuilder,
OptimizationOptions,
)
from experimental_experiment.xoptim import EasyPatternOptimization
proto = oh.make_model(
oh.make_graph(
[
oh.make_node("Sigmoid", ["Y"], ["sy"]),
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
oh.make_node("Mul", ["X", "ysy"], ["final"]),
],
"nd",
[
oh.make_tensor_value_info("X", TensorProto.FLOAT, [1, "b", "c"]),
oh.make_tensor_value_info("Y", TensorProto.FLOAT, ["a", "b", "c"]),
],
[oh.make_tensor_value_info("final", TensorProto.FLOAT, ["a", "b", "c"])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=9,
)
print(pretty_onnx(proto))
opset: domain='' version=18
input: name='X' type=dtype('float32') shape=[1, 'b', 'c']
input: name='Y' type=dtype('float32') shape=['a', 'b', 'c']
Sigmoid(Y) -> sy
Mul(Y, sy) -> ysy
Mul(X, ysy) -> final
output: name='final' type=dtype('float32') shape=['a', 'b', 'c']
And visually.
<Axes: >
The pattern¶
class MulMulSigmoidPattern(EasyPatternOptimization):
def match_pattern(self, g: GraphBuilder, X, Y):
return g.op.Mul(X, g.op.Mul(Y, g.op.Sigmoid(Y)))
def apply_pattern(self, g: GraphBuilder, X, Y):
return g.anyop.MulMulSigmoid(X, Y, domain="onnx_extended.ortops.optim.cuda")
Optimization¶
gr = GraphBuilder(
proto,
infer_shapes=True,
optimization_options=OptimizationOptions(
patterns=[MulMulSigmoidPattern(verbose=1)],
verbose=1, # a higher value increases the verbosity when optimizations for patterns
),
)
new_proto = gr.to_onnx()
print(pretty_onnx(new_proto))
[GraphBuilder.optimize] start with 3 nodes
[GraphBuilder.optimize] #patterns=1
[GraphBuilderPatternOptimization.optimize] start with 3 nodes, 0 initializers, 1 patterns, priorities=[0]
[GraphBuilderPatternOptimization.optimize] iteration 0: 3 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: MulMulSigmoidPattern replaces ['Sigmoid', 'Mul', 'Mul'] - time=0.001 | max_time=MulMulSigmoidPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 1: 1 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] done after 2 iterations with 1 nodes in 0.002
[GraphBuilder.optimize] done with 1 nodes in 0.002
opset: domain='' version=18
opset: domain='onnx_extended.ortops.optim.cuda' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='X' type=dtype('float32') shape=[1, 'b', 'c']
input: name='Y' type=dtype('float32') shape=['a', 'b', 'c']
MulMulSigmoid[onnx_extended.ortops.optim.cuda](X, Y) -> final
output: name='final' type=dtype('float32') shape=['a', 'b', 'c']
And visually.
<Axes: >
Filtering¶
Let’s assume now we want to apply the pattern only when the shapes are identical.
class MulMulSigmoidPattern2(EasyPatternOptimization):
def match_pattern(self, g: GraphBuilder, X, Y):
return g.op.Mul(X, g.op.Mul(Y, g.op.Sigmoid(Y)))
def apply_pattern(self, g: GraphBuilder, X, Y):
return g.anyop.MulMulSigmoid(X, Y, domain="onnx_extended.ortops.optim.cuda")
def validate_mapping(
self,
g: GraphBuilder,
deleted_nodes: List[NodeProto],
pattern_nodes: Optional[List[NodeProto]] = None,
) -> bool:
for node in deleted_nodes:
if (
node.op_type == "Mul"
and g.has_shape(node.input[0])
and g.has_shape(node.input[1])
):
sh1 = g.get_shape(node.input[0])
sh2 = g.get_shape(node.input[1])
if sh1 != sh2:
if self.verbose > 0:
print(
f"[MulMulSigmoidPattern2.validate_mapping] "
f"match not valid because shapes are different"
f"{node.input[0]}:{sh1} != {node.input[1]}:{sh2}"
)
return False
return True
gr = GraphBuilder(
proto,
infer_shapes=True,
optimization_options=OptimizationOptions(
patterns=[MulMulSigmoidPattern2(verbose=1)],
verbose=0,
),
)
new_proto = gr.to_onnx()
print(pretty_onnx(new_proto))
[MulMulSigmoidPattern2.validate_mapping] match not valid because shapes are differentX:(1, 'b', 'c') != ysy:('a', 'b', 'c')
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='X' type=dtype('float32') shape=[1, 'b', 'c']
input: name='Y' type=dtype('float32') shape=['a', 'b', 'c']
Sigmoid(Y) -> sy
Mul(Y, sy) -> ysy
Mul(X, ysy) -> final
output: name='final' type=dtype('float32') shape=['a', 'b', 'c']
Total running time of the script: (0 minutes 0.200 seconds)