to_onnx and submodules from LLMs
Big models are hard to read once converted into ONNX. Let's see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.
A simple LLM
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
self.pe = torch.nn.Embedding(vocab_size, embedding_dim)
def forward(self, x):
word_emb = self.embedding(x)
word_pe = self.pe(x)
return word_emb + word_pe
class AttentionBlock(torch.nn.Module):
def __init__(self, embedding_dim: int, context_size: int):
super().__init__()
self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
self.register_buffer(name="mask", tensor=torch.tril(input=ones))
def forward(self, x):
_B, T, C = x.size()
query = self.query(x)
key = self.key(x)
value = self.value(x)
qk = query @ key.transpose(-2, -1) * C**-0.5
attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
attention = torch.nn.functional.softmax(input=attention, dim=-1)
out = attention @ value
return out
class MultiAttentionBlock(torch.nn.Module):
def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
super().__init__()
self.attention = torch.nn.ModuleList(
modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
)
self.linear = torch.nn.Linear(
in_features=embedding_dim * num_heads, out_features=embedding_dim
)
def forward(self, x):
out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
x = self.linear(out)
return x
class FeedForward(torch.nn.Module):
def __init__(self, embedding_dim: int, ff_dim: int):
super().__init__()
self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
self.relu = torch.nn.ReLU()
self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)
def forward(self, x):
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class DecoderLayer(torch.nn.Module):
def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
super().__init__()
self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
self.feed_forward = FeedForward(embedding_dim, ff_dim)
self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
def forward(self, x):
x_norm = self.norm_1(x)
attention = self.attention(x_norm)
attention = attention + x
attention_norm = self.norm_2(attention)
ff = self.feed_forward(attention_norm)
ff = ff + attention
return ff
class LLM(torch.nn.Module):
def __init__(
self,
vocab_size: int = 1024,
embedding_dim: int = 16,
num_heads: int = 2,
context_size: int = 256,
ff_dim: int = 128,
):
super().__init__()
self.embedding = Embedding(vocab_size, embedding_dim)
self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)
def forward(self, input_ids):
x = self.embedding(input_ids)
y = self.decoder(x)
return y
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.1491379737854, max=4.618813514709473
First conversion to ONNX
The conversion relies on torch.export.export(), which gives:
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
graph():
%p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
%p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
%p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
%p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
%p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
%p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
%p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
%p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
%p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
%p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
%p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
%p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
%p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
%p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
%p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
%p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
%p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
%p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
%b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
%b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
%input_ids : [num_users=2] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
%embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
%linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
%transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
%matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
%eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
%masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
%softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
%matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
%linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
%linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
%linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
%transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
%matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
%slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
%masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
%softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
%matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
%linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
%add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
%linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
%linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
return (add_2,)
Then the function to_onnx converts it into ONNX.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='SliceSlicePattern_init7_s1_0_start' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilder.constant_folding.from/fold(init7_s1_0)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilder.constant_folding.from/fold(init7_s1_30)##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, init1_s_2::RSh1) -> eq
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check there is no discrepancy.
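A minimal sketch of such a check, assuming onnxruntime runs the model and that max_diff (imported above) returns a dictionary whose 'abs' entry is the maximum absolute difference:

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, {"input_ids": input_ids.numpy()})[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
# 'abs' is assumed to hold the maximum absolute difference
print(f"max discrepancy={max_diff(y, got)['abs']}")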
output: shape=(1, 30, 16), min=-4.1491379737854, max=4.618813514709473
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules
Let’s produce an ONNX model with submodules.
The function to_onnx calls torch.export.unflatten.unflatten() under the hood. The FX graph looks like the following.
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
graph():
%input_ids : [num_users=1] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy, but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals.
The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule and coming from the previous graph, as the sketch below illustrates.
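A minimal sketch of how to verify this, assuming the unflattened module exposes its parts through the standard torch.nn.Module.get_submodule accessor:

emb = unflatten_ep.get_submodule("embedding")
print(type(emb))                   # InterpreterModule, not Embedding
print(isinstance(emb, Embedding))  # False
print(emb.graph)                   # the fx nodes moved into this submodule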
Now the ONNX graph.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[0.0547443...) -> bias
Add(_onx_matmul_cat, bias) -> attention
Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0113851...) -> bias2
Add(_onx_matmul_relu, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
We check again that there are no new discrepancies.
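A sketch under the same assumptions as the first check (onnxruntime also accepts models with local functions):

sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, {"input_ids": input_ids.numpy()})[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={max_diff(y, got)['abs']}")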
output: shape=(1, 30, 16), min=-4.1491379737854, max=4.618813514709473
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually.
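The figure relies on plot_dot imported at the top; a minimal sketch, assuming plot_dot accepts a ModelProto and draws it with graphviz into a matplotlib axis:

import matplotlib.pyplot as plt

plot_dot(onx_module)  # assumed: renders the ONNX graph into the current figure
plt.show()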
[figure: the ONNX model with submodules]
Inlining
The ONNX graph can still be inlined after this step.
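A minimal sketch of that step, using inline_local_functions imported from onnx.inliner at the top:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))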
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[0.0547443...) -> bias
Add(_onx_matmul_cat, bias) -> attention
Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0113851...) -> bias2
Add(_onx_matmul_relu, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations
The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That's why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let's see how it goes.
onx_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True, verbose=2),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-CQY.optimize] start with 73 nodes
[GraphBuilder-CQY.optimize] #patterns=92
[GraphBuilder-CQY.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-CQY.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-CQY.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-CQY.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-CQY.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-CQY.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-CQY.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-CQY.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-CQY.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-CQY.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-CQY.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-CQY.optimize] start with 53 nodes, 28 initializers, 92 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 1/92 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 2/92 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 3/92 - P0 - CastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 4/92 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 5/92 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 6/92 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 7/92 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 8/92 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 9/92 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 10/92 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 11/92 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 12/92 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 13/92 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 14/92 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 15/92 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 16/92 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 17/92 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 18/92 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 19/92 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 20/92 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 21/92 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 22/92 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 23/92 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 24/92 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 25/92 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 26/92 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 27/92 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 28/92 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 29/92 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 30/92 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 31/92 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 32/92 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 33/92 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 34/92 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 35/92 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 36/92 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 37/92 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 38/92 - P1 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 39/92 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 40/92 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 41/92 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 42/92 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 43/92 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 44/92 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 45/92 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 46/92 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 47/92 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 48/92 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 49/92 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 50/92 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 51/92 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 52/92 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 53/92 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 54/92 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 55/92 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 56/92 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 57/92 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 58/92 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 59/92 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 60/92 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 61/92 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 62/92 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 63/92 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 64/92 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 65/92 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 66/92 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 67/92 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 68/92 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 69/92 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 70/92 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 71/92 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 72/92 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 73/92 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 74/92 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 75/92 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 76/92 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 77/92 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 78/92 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 79/92 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 80/92 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 81/92 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 82/92 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 83/92 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 84/92 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 85/92 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 86/92 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 87/92 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 88/92 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 89/92 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 90/92 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 91/92 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern 92/92 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-CQY.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-CQY.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.005 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilder-CQY.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-CQY.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-CQY.optimize] increase priority to 1
[GraphBuilderPatternOptimization-CQY.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-CQY.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilder-CQY.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-CQY.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-CQY.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-CQY.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-CQY.optimize] increase priority to 2
[GraphBuilderPatternOptimization-CQY.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-CQY.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilder-CQY.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-CQY.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-CQY.optimize] increase priority to 3
[GraphBuilderPatternOptimization-CQY.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-CQY.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-CQY.optimize] done after 8 iterations with 29 nodes in 0.035
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00014969099902373273
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0004874780006502988
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00026582700229482725
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004067619993293192
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00016573200082348194
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0009638310002628714
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00012380699990899302
STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.001013454999338137
STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008280439978989307
STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.132400004484225e-05
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00021630700030073058
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00018701699809753336
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012142100058554206
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001161640029749833
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00034005799898295663
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014117699720372912
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003156380025757244
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.000210778001928702
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000125485999888042
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020566500097629614
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019073399926128332
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023604700254509225
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001860960019257618
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014363699847308453
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001273160014534369
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001797380009520566
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011473199992906302
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011958300274272915
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00019102400074189063
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011557599827938247
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001246270003321115
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001639339989196742
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011912799891433679
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012439199235814158
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011644099868135527
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=5.650899947795551e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00012853699990955647
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00029809500119881704
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.78490001114551e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.3609001042786986e-05
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.001709799000309431
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0020909830018354114
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=6.7550026869867e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012081099885108415
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0017317369984084507
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00025622099929023534
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014180600192048587
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0021210819995758357
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.162699926586356e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005160870005056495
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002892239990615053
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001922980009112507
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001770939979905961
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011977900066995062
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000114618000225164
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016320499889843632
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012286600394872949
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003638580001279479
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.9344999600434676e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026351499946031254
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020138600120844785
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023555299958388787
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017352400027448311
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016819400298118126
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013320899779500905
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005386399971030187
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012667300325119868
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001282800003536977
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017276099970331416
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043448399810586125
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032721499701438006
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015029099813546054
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030225499904190656
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019854699712595902
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003244519994041184
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020433000099728815
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018761299907055218
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023127600252337288
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001735459973133402
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023400800455419812
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00015971799984981772
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012341500041657127
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001424369984306395
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001272450008400483
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002935329999672831
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001218049965245882
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012243399760336615
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027939600113313645
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001865780013758922
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001748650029185228
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013498599946615286
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035561500408221036
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018729499970504548
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000184643000466167
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.041100020112935e-05
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002750849998847116
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002799320027406793
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001963239992619492
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019400299970584456
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020028600192745216
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017426000340492465
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0006548620003741235
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0019251579979027156
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-CQY.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-CQY.optimize] done with 29 nodes in 0.043
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows the kernel FusedMatMul[com.microsoft],
which implements an operator equivalent to Gemm
but working on tensors of any rank, not only 2D. Note that alpha=0.25
is the attention scaling C**-0.5 with C=16 folded directly into the kernel.
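As a quick sanity check, the kernel's semantics can be reproduced with numpy.
This is a minimal sketch assuming the documented behaviour
alpha * A @ B^T with broadcasting over the batch dimensions; the inputs are
random placeholders matching the shapes shown above.

import numpy as np

A = np.random.rand(1, 30, 16).astype(np.float32)  # query
B = np.random.rand(1, 30, 16).astype(np.float32)  # key
# FusedMatMul(A, B, alpha=0.25, transB=1) should match:
expected = 0.25 * A @ np.swapaxes(B, -2, -1)
print(expected.shape)  # (1, 30, 30), the attention logits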
How does it work on the model which exports the modules as local functions?
The optimizer optimizes every local function independently.
We reduce the verbosity…
onx_module_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias) -> attention
SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
It seems to work as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
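If a runtime does not support local functions, the model can be flattened
back into a single graph with onnx.inliner, already imported at the top of
this example. A minimal sketch:

from onnx.inliner import inline_local_functions

flat_model = inline_local_functions(onx_module_optimized)
print(f"{len(flat_model.graph.node)} nodes after inlining")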
Optimizations for CUDA¶
The optimizer may behave differently knowing the model runs on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more suitable for some kernels. This could even be decided at runtime.
onx_cuda_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(
patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-IMQ.optimize] start with 73 nodes
[GraphBuilder-IMQ.optimize] #patterns=92
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-IMQ.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-IMQ.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-IMQ.optimize] start with 53 nodes, 28 initializers, 92 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 1/92 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 2/92 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 3/92 - P0 - CastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 4/92 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 5/92 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 6/92 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 7/92 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 8/92 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 9/92 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 10/92 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 11/92 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 12/92 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 13/92 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 14/92 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 15/92 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 16/92 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 17/92 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 18/92 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 19/92 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 20/92 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 21/92 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 22/92 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 23/92 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 24/92 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 25/92 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 26/92 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 27/92 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 28/92 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 29/92 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 30/92 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 31/92 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 32/92 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 33/92 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 34/92 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 35/92 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 36/92 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 37/92 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 38/92 - P1 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 39/92 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 40/92 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 41/92 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 42/92 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 43/92 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 44/92 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 45/92 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 46/92 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 47/92 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 48/92 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 49/92 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 50/92 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 51/92 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 52/92 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 53/92 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 54/92 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 55/92 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 56/92 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 57/92 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 58/92 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 59/92 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 60/92 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 61/92 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 62/92 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 63/92 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 64/92 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 65/92 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 66/92 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 67/92 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 68/92 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 69/92 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 70/92 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 71/92 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 72/92 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 73/92 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 74/92 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 75/92 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 76/92 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 77/92 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 78/92 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 79/92 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 80/92 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 81/92 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 82/92 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 83/92 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 84/92 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 85/92 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 86/92 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 87/92 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 88/92 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 89/92 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 90/92 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 91/92 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern 92/92 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-IMQ.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.005 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-IMQ.optimize] increase priority to 1
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-IMQ.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-IMQ.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-IMQ.optimize] increase priority to 2
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-IMQ.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=ShapeBasedExpandBroadcastPattern:0.000
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-IMQ.optimize] increase priority to 3
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-IMQ.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-IMQ.optimize] done after 8 iterations with 29 nodes in 0.038
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00014954299876990262
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0005257770008029183
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00021109499539306853
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0003963510007451987
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0001277029987249989
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0009879899989755359
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010936199942079838
STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.001020333000269602
STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.000959601997237769
STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.802400043350644e-05
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023113700262911152
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0001938599980348954
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014631800149800256
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002773910000541946
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00036972099769627675
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017115199989348184
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003359580023243325
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00021366299915825948
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013018200115766376
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00027450100242276676
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021505200311366934
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002359770005568862
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001970469984371448
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001500750022387365
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001401189983880613
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00017816000399761833
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011399300092307385
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012389000039547682
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00018147999981010798
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012342700392764527
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013117599883116782
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002151770022464916
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012704699656751473
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013118899914843496
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012282100033189636
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.013999700371642e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0001327500012848759
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0003122279995295685
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.7900000936351717e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.283999962557573e-05
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0018348789999436121
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0024085169989120914
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=7.433000064338557e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012835999586968683
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.002097357997627114
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002549539985921001
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015446699944732245
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0020164580000709975
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.646799945679959e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005335019995982293
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030084299760346767
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017048499830707442
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022499699844047427
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012595100270118564
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029390700001385994
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016994800353131723
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000127850998978829
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003754720000870293
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.03159997909097e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00027197800045541953
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000184544001967879
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002436279974062927
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019114799943054095
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017875799676403403
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011973899927397724
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005322800025169272
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013045500236330554
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001320669998676749
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017762900097295642
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039467299939133227
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005254629959381418
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015358899872808252
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003153239958919585
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022683799943479244
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003335820001666434
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023295500068343244
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019475999943097122
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023828400298953056
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018373899729340337
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022436499966715928
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0001598860017111292
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013215099897934124
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000124246000268613
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012736800090351608
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0037673959996027406
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014079099855734967
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001317490005021682
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003830739951808937
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021182200180192012
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001850079988798825
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014682700020784978
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039244999607035425
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022648400226898957
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024795500030450057
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.21280011930503e-05
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030491500001517124
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003235440017306246
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020645799850171898
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021316400125215296
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022563499987882096
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017694100097287446
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0009510040017630672
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0018855740017897915
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-IMQ.optimize] done with 29 nodes in 0.045
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Comparison between the optimized and non-optimized models¶
The following tool tries to match the nodes and the inferred shapes of two models. If they are not too different, the function is able to find out the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.
res1, res2, align, dc = compare_onnx_execution(
onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 63 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 63 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 63 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32 2:256x256 AOCQ b_ | INITIA float32 1:1 ?AAA in
002 - | INITIA float32 2:256x256 AOCQ b_ |
003 ~ | INITIA float32 1:1 ?AAA in | INITIA float32 2:16x16 AZAA p_
004 ~ | INITIA float32 2:16x16 AZAA p_ | INITIA float32 2:16x16 BAAZ p_
005 ~ | INITIA float32 2:16x16 BAAZ p_ | INITIA float32 2:16x16 YZBZ p_
006 ~ | INITIA float32 2:16x16 YZBZ p_ | INITIA float32 2:30x30 KGSP sl
007 = | INITIA float32 1:1 AAAA in | INITIA float32 1:1 AAAA in
008 ~ | INITIA float32 1:1 AAAA in | INITIA float32 2:16x16 AABA p_
009 ~ | INITIA float32 2:16x16 AABA p_ | INITIA float32 2:16x16 BAAA p_
010 ~ | INITIA float32 2:16x16 BAAA p_ | INITIA float32 2:16x16 ZBCA p_
011 ~ | INITIA float32 2:16x16 ZBCA p_ | INITIA float32 2:30x30 KGSP sl
012 = | INITIA float32 2:32x16 AAAA p_ | INITIA float32 2:32x16 AAAA p_
013 = | INITIA float32 2:16x128 XFBD p_ | INITIA float32 2:16x128 XFBD p_
014 = | INITIA float32 2:128x16 ACBA p_ | INITIA float32 2:128x16 ACBA p_
015 = | INITIA float32 1:16 EEEE in | INITIA float32 1:16 EEEE in
016 = | INITIA float32 1:16 AAAA in | INITIA float32 1:16 AAAA in
017 - | INITIA int64 1:2 AAAA Sl |
018 - | INITIA int64 1:2 EEAA Sl |
019 - | INITIA int64 1:2 ABAA Sl |
020 = | INITIA float32 2:1024x16 RSYN em | INITIA float32 2:1024x16 RSYN em
021 = | INITIA float32 2:1024x16 QUFT em | INITIA float32 2:1024x16 QUFT em
022 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
023 = | INITIA float32 1:128 ZAAA de | INITIA float32 1:128 ZAAA de
024 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
025 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
026 = | RESULT float32 3:1x30x16 JFNV Gather em | RESULT float32 3:1x30x16 JFNV Gather em
027 = | RESULT float32 3:1x30x16 OEIP Gather em | RESULT float32 3:1x30x16 OEIP Gather em
028 ~ | RESULT float32 3:1x30x16 WKWJ Add ad | RESULT float32 3:1x30x16 ZBAA SkipLayerNormal _o
029 ~ | RESULT float32 3:1x30x16 ZBAA LayerNormalizat _o | RESULT float32 3:1x30x1 ZAAA SkipLayerNormal un
030 ~ | RESULT float32 3:1x30x16 ATFB MatMul li | RESULT float32 3:1x30x1 GGFE SkipLayerNormal un
031 ~ | RESULT float32 3:1x30x16 ZEWY MatMul li | RESULT float32 3:1x30x16 WKWJ SkipLayerNormal ad
032 ~ | RESULT float32 3:1x30x16 IAFR MatMul li | RESULT float32 3:1x30x16 ATFB MatMul li
033 ~ | RESULT float32 3:1x16x30 ZZZA Transpose tr | RESULT float32 3:1x30x16 ZEWY MatMul li
034 ~ | RESULT float32 3:1x30x30 QWMM MatMul ma | RESULT float32 3:1x30x30 YFXX FusedMatMul _o
035 ~ | RESULT float32 3:1x30x30 YFXX Mul _o | RESULT float32 3:1x30x16 IAFR MatMul li
036 - | RESULT float32 2:30x30 KGSP Slice sl |
037 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
038 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
039 = | RESULT float32 3:1x30x30 IHHH Softmax so | RESULT float32 3:1x30x30 IHHH Softmax so
040 = | RESULT float32 3:1x30x16 DGCH MatMul ma | RESULT float32 3:1x30x16 DGCH MatMul ma
041 = | RESULT float32 3:1x30x16 CCAF MatMul li | RESULT float32 3:1x30x16 CCAF MatMul li
042 = | RESULT float32 3:1x30x16 UCAT MatMul li | RESULT float32 3:1x30x16 UCAT MatMul li
043 ~ | RESULT float32 3:1x30x16 EXDP MatMul li | RESULT float32 3:1x30x30 AXAX FusedMatMul _o
044 ~ | RESULT float32 3:1x16x30 OCCX Transpose tr | RESULT float32 3:1x30x16 EXDP MatMul li
045 ~ | RESULT float32 3:1x30x30 AODM MatMul ma | RESULT bool 2:30x30 HLZC Equal eq
046 ~ | RESULT float32 3:1x30x30 AXAX Mul _o | RESULT float32 3:1x30x30 ???? Where ma
047 ~ | RESULT float32 2:30x30 KGSP Slice sl | RESULT float32 3:1x30x30 IHHH Softmax so
048 - | RESULT bool 2:30x30 HLZC Equal eq |
049 ~ | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x16 HBAC MatMul ma
050 ~ | RESULT float32 3:1x30x30 IHHH Softmax so | RESULT float32 3:1x30x32 LHCK Concat ca
051 ~ | RESULT float32 3:1x30x16 HBAC MatMul ma | RESULT float32 3:1x30x16 AAAA MatMul _o
052 ~ | RESULT float32 3:1x30x32 LHCK Concat ca | RESULT float32 3:1x30x16 WWVV Add li
053 ~ | RESULT float32 3:1x30x16 AAAA MatMul _o | RESULT float32 3:1x30x16 ZBAA SkipLayerNormal _o
054 ~ | RESULT float32 3:1x30x16 WWVV Add li | RESULT float32 3:1x30x1 YZAZ SkipLayerNormal un
055 ~ | RESULT float32 3:1x30x16 SGRE Add ad | RESULT float32 3:1x30x1 GGFE SkipLayerNormal un
056 ~ | RESULT float32 3:1x30x16 ZBAA LayerNormalizat _o | RESULT float32 3:1x30x16 SGRE SkipLayerNormal ad
057 = | RESULT float32 3:1x30x128 AZSV MatMul _o | RESULT float32 3:1x30x128 AZSV MatMul _o
058 = | RESULT float32 3:1x30x128 XXOS Add li | RESULT float32 3:1x30x128 XXOS Add li
059 = | RESULT float32 3:1x30x128 RDDF Relu re | RESULT float32 3:1x30x128 RDDF Relu re
060 = | RESULT float32 3:1x30x16 EDHJ MatMul _o | RESULT float32 3:1x30x16 EDHJ MatMul _o
061 = | RESULT float32 3:1x30x16 HGKL Add li | RESULT float32 3:1x30x16 HGKL Add li
062 = | RESULT float32 3:1x30x16 ZMCQ Add ou | RESULT float32 3:1x30x16 ZMCQ Add ou
063 = | OUTPUT float32 3:1x30x16 ZMCQ ou | OUTPUT float32 3:1x30x16 ZMCQ ou
The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that’s a topic for another example.
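A hypothetical sketch of what that could look like, assuming to_onnx accepts
a dynamic_shapes specification in the same format as torch.export.export
(the argument name here is an assumption, not taken from this example):

seq = torch.export.Dim("seq", min=1, max=30)  # symbolic sequence length
onx_dynamic = to_onnx(
    llm,
    (input_ids,),
    dynamic_shapes=({1: seq},),  # axis 1 of input_ids becomes dynamic
)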
Total running time of the script: (0 minutes 1.560 seconds)
Related examples

to_onnx and padding one dimension to a multiple of a constant