Note
Go to the end to download the full example code.
to_onnx and submodules from LLMs
Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with PyTorch.
A simple LLM
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
    """Sum of two learned lookup tables applied to the same token ids.

    NOTE(review): ``pe`` is indexed with the token ids ``x`` rather than with
    token positions, so despite its name it behaves like a second token
    embedding, not a positional encoding. It is kept as-is because the
    exported graphs shown in this example depend on it.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        # Both tables share the same (vocab_size, embedding_dim) geometry;
        # registration order fixes the parameter order seen by the exporter.
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        """Look ``x`` up in both tables and return the element-wise sum."""
        return self.embedding(x) + self.pe(x)
class AttentionBlock(torch.nn.Module):
    """Single causal self-attention head.

    Queries, keys and values are bias-free projections to ``embedding_dim``.
    Scores are scaled by ``C ** -0.5``, where ``C`` is the last dimension of
    the input. A lower-triangular ``mask`` buffer of shape
    ``(context_size, context_size)`` hides positions above the diagonal.
    """

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        # 1 on and below the diagonal, 0 above: row i may attend to columns j <= i.
        causal = torch.tril(
            input=torch.ones(size=[context_size, context_size], dtype=torch.float)
        )
        self.register_buffer(name="mask", tensor=causal)

    def forward(self, x):
        """Attend over ``x``, which is unpacked as (batch, seq_len, channels).

        NOTE(review): the mask slice assumes seq_len <= context_size; longer
        inputs would produce a mask that does not broadcast against the scores.
        """
        _, seq_len, channels = x.size()
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        scores = q @ k.transpose(-2, -1) * channels**-0.5
        blocked = self.mask[:seq_len, :seq_len] == 0
        weights = torch.nn.functional.softmax(
            input=scores.masked_fill(blocked, float("-inf")), dim=-1
        )
        return weights @ v
class MultiAttentionBlock(torch.nn.Module):
    """Run ``num_heads`` independent AttentionBlock heads and mix their outputs.

    Each head yields ``embedding_dim`` features. The heads are concatenated
    along the last axis and projected back to ``embedding_dim`` by a biased
    linear layer.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        heads = [AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        self.attention = torch.nn.ModuleList(modules=heads)
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        """Concatenate every head's output for ``x`` and project the result."""
        per_head = [head(x) for head in self.attention]
        return self.linear(torch.cat(tensors=per_head, dim=-1))
class FeedForward(torch.nn.Module):
    """Two-layer MLP on the last axis: Linear(embedding_dim -> ff_dim), ReLU, Linear back."""

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        """Expand ``x`` to ``ff_dim``, rectify it and project it back."""
        hidden = self.relu(self.linear_1(x))
        return self.linear_2(hidden)
class DecoderLayer(torch.nn.Module):
    """Decoder layer with LayerNorm applied before each sublayer.

    The attention output is added to the un-normalized input. The feed-forward
    output is then added to that intermediate sum.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        """Apply the attention and feed-forward sublayers, each with a residual."""
        residual = self.attention(self.norm_1(x)) + x
        return self.feed_forward(self.norm_2(residual)) + residual
class LLM(torch.nn.Module):
    """Tiny language-model body: an Embedding followed by one DecoderLayer.

    ``forward`` returns the decoder's hidden states (last axis of size
    ``embedding_dim``). There is no projection back to the vocabulary.
    """

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        """Embed ``input_ids`` and run them through the decoder layer."""
        return self.decoder(self.embedding(input_ids))
llm = LLM()
dim = (1, 30)
# torch.randint's upper bound is exclusive, so the ids span the whole vocabulary.
# Reading the size from the model keeps the inputs valid if vocab_size changes.
vocab_size = llm.embedding.embedding.num_embeddings
input_ids = torch.randint(0, vocab_size, dim, dtype=torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.608179092407227, max=4.440975189208984
First conversion to ONNX
The conversion relies on torch.export.export(), which gives:
# Export the eager model with torch.export and print its flat FX graph.
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
graph():
%p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
%p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
%p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
%p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
%p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
%p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
%p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
%p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
%p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
%p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
%p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
%p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
%p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
%p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
%p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
%p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
%p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
%p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
%b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
%b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
%input_ids : [num_users=2] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
%embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias, 1e-05, False), kwargs = {})
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
%linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
%transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
%matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
%eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
%masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
%softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
%matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
%linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
%linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
%linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
%transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
%matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
%slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
%masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
%softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
%matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
%linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
%add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias, 1e-05, False), kwargs = {})
%linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
%linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
return (add_2,)
Then the function to_onnx converts it into ONNX.
# Convert the model to ONNX and print a readable text summary of the graph.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check that there is no discrepancy.
output: shape=(1, 30, 16), min=-4.608178615570068, max=4.440975189208984
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
# Save the first exported model (no submodule structure) to disk.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules
Let’s produce an ONNX model with submodules.
The function to_onnx calls torch.export.unflatten.unflatten() under the hood.
The FX graph looks like the following.
# Export again, then unflatten the program. The resulting top-level graph
# shows call_module nodes for the submodules (embedding, decoder).
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
graph():
%input_ids : [num_users=1] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy, but it does not necessarily preserve the signatures
of the initial modules. That was not one of our goals.
The tricky part is that the module called (embedding) is not an instance of Embedding
but an instance of InterpreterModule,
and it contains the fx nodes contributing to the submodule that come from the
previous graph.
Now the ONNX graph.
# export_modules_as_functions=True requests that submodules be exported as ONNX
# local functions (note the 'aten_local_function' opset in the printed model).
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias_cst2init) -> attention
Add(attention, embedding) -> add_1
LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
We check again that there are no new discrepancies.
output: shape=(1, 30, 16), min=-4.608178615570068, max=4.440975189208984
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
# Save the model exported with export_modules_as_functions=True.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually.

Inlining
The ONNX graph can still be inlined after this.
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias_cst2init) -> attention
Add(attention, embedding) -> add_1
LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations¶
The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to use kernels specific to onnxruntime (operators from the `com.microsoft` domain, such as FusedMatMul or SkipLayerNormalization). Let’s see how it goes.
# Export the same model again. This time the optimizer is asked to apply the
# default patterns plus the onnxruntime-specific ones, to fold constants, and
# to log its progress (verbose=2) so each optimization step is visible.
options = OptimizationOptions(
    patterns="default+onnxruntime",
    constant_folding=True,
    verbose=2,
)
onx_optimized = to_onnx(llm, (input_ids,), options=options)
print(pretty_onnx(onx_optimized))
[GraphBuilder-MPC.optimize] start with 73 nodes
[GraphBuilder-MPC.optimize] #patterns=110
[GraphBuilder-MPC.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-MPC.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-MPC.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-MPC.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-MPC.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-MPC.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-MPC.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-MPC.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-MPC.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-MPC.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-MPC.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-MPC.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 46/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 47/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 48/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 49/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 50/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 51/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 52/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 53/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 54/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 55/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 56/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 57/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 58/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 59/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 60/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 61/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 62/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 63/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 64/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 65/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 66/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 67/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 68/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 69/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 70/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 71/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 72/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 73/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 74/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 75/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 76/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 77/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 78/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 79/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 80/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 81/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 82/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 83/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 84/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 85/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 86/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 87/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 88/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 89/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 90/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 91/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 92/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 93/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 94/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 95/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 96/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 97/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 98/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 99/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 100/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 101/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 102/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 103/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 104/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 105/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 106/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 107/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 108/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 109/110 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-MPC.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-MPC.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-MPC.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.009 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilder-MPC.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-MPC.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-MPC.optimize] increase priority to 1
[GraphBuilderPatternOptimization-MPC.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-MPC.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.007 | max_time=GeluOrtPattern:0.000
[GraphBuilder-MPC.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-MPC.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-MPC.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-MPC.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-MPC.optimize] increase priority to 2
[GraphBuilderPatternOptimization-MPC.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-MPC.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.007 | max_time=IdentityPattern:0.001
[GraphBuilder-MPC.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-MPC.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-MPC.optimize] increase priority to 3
[GraphBuilderPatternOptimization-MPC.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-MPC.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-MPC.optimize] done after 8 iterations with 29 nodes in 0.062
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002436219983792398
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.001055629996699281
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00028470299730543047
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0006414859999495093
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0002723729958233889
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0015001110004959628
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011994099986623041
STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.0631003533490002e-05
STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007940730065456592
STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.000843846002680948
STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007394740023300983
STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007384079981420655
STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0011613660062721465
STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.010924564998276765
STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0034808639975381084
STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009606732000975171
STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0055053940013749525
STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009232807999069337
STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009346571001515258
STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.007278225999471033
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00030220799817470834
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023898099971120246
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024219399711000733
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022538299526786432
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005251700022199657
STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002399180011707358
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019194499691366218
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004212729945720639
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00026861899823416024
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017949100219993852
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002571740042185411
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003091939943260513
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00024347499493160285
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021676699907402508
STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018205300511908717
STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023227000565384515
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=9.518300066702068e-05
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000253029000305105
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015309299487853423
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001718380044621881
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023593799778609537
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002704960024857428
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017089699758798815
STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032928699874901213
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043007500426028855
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019621199317043647
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021913900127401575
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019528599659679458
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00011403099415474571
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00019527900076354854
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004933470045216382
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.759000098099932e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.633699715370312e-05
STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002442819968564436
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0026402769981359597
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0034772930048347916
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.499991963617504e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002626389978104271
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0036402740006451495
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00035562499760999344
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020553099966491573
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0036149469997326378
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.949299833853729e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009024710016092286
STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017280600150115788
STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015789300960022956
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005358339949452784
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025308700423920527
STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001983840011234861
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024312499590450898
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015967600484145805
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020447399583645165
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024331200256710872
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017762599964044057
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004994959999748971
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.1007002184633166e-05
STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.7934998797718436e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00037027399594080634
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024220799969043583
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003843319973384496
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026758600142784417
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025412999821128324
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031544900048174895
STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000556958002562169
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0009793869976419955
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016833399786264636
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002430500026093796
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002759019989753142
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006391040078597143
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005725939990952611
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025077899408643134
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044415500087779947
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003154839978378732
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005192650023673195
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003574580005079042
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002703989994188305
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004020529995614197
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033265500678680837
STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026588600303512067
STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002063550055027008
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031777500043972395
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00020218500139890239
STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001938030000019353
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014733300122315995
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015989899839041755
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001624830038053915
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004558331002044724
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016203400082304142
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016109299394884147
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003721749999385793
STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002465739999024663
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000236893003602745
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002213320003647823
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018208700566901825
STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025003900009323843
STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032406800164608285
STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002539649976824876
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047074999747565016
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002660249992914032
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002676740004972089
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011481599722173996
STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026388999685877934
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00038005400347174145
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004045329951622989
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026186600007349625
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00045842300096410327
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003965280011470895
STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002622110005177092
STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002332070034753997
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023478899674955755
STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=5.5829001212259755e-05
STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002393529997789301
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0025612999925215263
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-MPC.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00034933200004161336s with changed=0 scale=0
[GraphBuilder-MPC.optimize] done with 29 nodes in 0.073
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows a kernel FusedMatMul[com.microsoft] which implements a kernel similar to Gemm
(a matrix multiplication with alpha scaling and optional transposition) but working on tensors of any rank, not only 2D.
How does it work on a model that exports the modules as local functions?
The optimizer optimizes every local function independently.
We reduce the verbosity…
# Export the same model again, but keep every submodule as an ONNX local
# function (export_modules_as_functions=True). The optimizer, using the
# "default+onnxruntime" patterns plus constant folding, is applied to each
# local function independently.
onx_module_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
export_modules_as_functions=True,
)
# pretty_onnx renders the optimized model as readable text (output below).
print(pretty_onnx(onx_module_optimized))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias) -> attention
SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
It seems to work in this simple case as well, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
Optimizations for CUDA¶
The optimizer may behave differently when it knows the model is running on CUDA. It may use different kernels and apply different optimizations if needed. This may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and is not able to decide which provider would be more useful for some kernels. This could even be decided at runtime.
# Same export and pattern set as before, but tell the optimizer the model is
# meant to run on CUDA (processor="CUDA") so it may choose different
# kernels/optimizations. verbose=2 turns on the detailed optimizer log
# printed below.
onx_cuda_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(
patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
),
)
# Render the optimized model as readable text.
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-FOA.optimize] start with 73 nodes
[GraphBuilder-FOA.optimize] #patterns=110
[GraphBuilder-FOA.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-FOA.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-FOA.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-FOA.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-FOA.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-FOA.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-FOA.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-FOA.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-FOA.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-FOA.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-FOA.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-FOA.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 46/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 47/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 48/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 49/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 50/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 51/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 52/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 53/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 54/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 55/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 56/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 57/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 58/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 59/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 60/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 61/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 62/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 63/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 64/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 65/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 66/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 67/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 68/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 69/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 70/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 71/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 72/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 73/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 74/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 75/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 76/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 77/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 78/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 79/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 80/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 81/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 82/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 83/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 84/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 85/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 86/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 87/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 88/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 89/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 90/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 91/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 92/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 93/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 94/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 95/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 96/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 97/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 98/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 99/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 100/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 101/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 102/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 103/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 104/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 105/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 106/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 107/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 108/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 109/110 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-FOA.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-FOA.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-FOA.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.008 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilder-FOA.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-FOA.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-FOA.optimize] increase priority to 1
[GraphBuilderPatternOptimization-FOA.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-FOA.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-FOA.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-FOA.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-FOA.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-FOA.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-FOA.optimize] increase priority to 2
[GraphBuilderPatternOptimization-FOA.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-FOA.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-FOA.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-FOA.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-FOA.optimize] increase priority to 3
[GraphBuilderPatternOptimization-FOA.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-FOA.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-FOA.optimize] done after 8 iterations with 29 nodes in 0.055
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002181369964091573
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0007028329964668956
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0003000720025738701
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0007817269979568664
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00026105700453626923
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.001309169991145609
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011582499791984446
STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=9.866991604212672e-06
STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008425270061707124
STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006346559966914356
STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006706179992761463
STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006892230012454093
STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0009456760017201304
STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0097180170014326
STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0032827739996719174
STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.007376120996923419
STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.005867957999726059
STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.010904203998507
STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0066166539982077666
STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.005185133002669318
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00030789999800617807
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00022889200045028701
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016377499923692085
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003267720021540299
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004484860000957269
STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002589309951872565
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017885099805425853
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003952940023737028
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00024390999897150323
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015442700168932788
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023573599173687398
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002687539999897126
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022327400074573234
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001801959988370072
STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015055299809318967
STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001506930057075806
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=7.587900108774193e-05
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00021298100182320923
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014216300041880459
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016015299843274988
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002688210006454028
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001510740039520897
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016074200175353326
STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003547100022842642
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022341599105857313
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018347299555898644
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014886899953125976
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001650890008022543
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.0002868170013243798
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00015636600073776208
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004014039986941498
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.942399992723949e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.277499753399752e-05
STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020329499602667056
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002478865993907675
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003275160001066979
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.413004252361134e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015451700164703652
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0029474940020008944
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00033269300547544844
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020108699391130358
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003262512997025624
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.848300083423965e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000729836003301898
STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016591500389040448
STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014830600048298948
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00036613499469240196
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020516600125120021
STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018753900440060534
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002622419997351244
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001616139998077415
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014659300359198824
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020128100004512817
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014963000285206363
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044195000009494834
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.6620000426191837e-05
STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.711500201257877e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003197980004188139
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029100999745423906
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028975700479350053
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029472999449353665
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020567100000334904
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001450729978387244
STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004614160061464645
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0009011590045702178
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001538010037620552
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015636100215488113
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023870299628470093
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044036399413016625
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039383599869324826
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020285500067984685
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000424265002948232
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002908349997596815
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004080829967278987
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030545599656761624
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031663600384490564
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003778260033868719
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025586899937479757
STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024193900026148185
STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003787440073210746
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026787999377120286
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00020831299480050802
STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016877999951248057
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014769299741601571
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014656299754278734
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001638830071897246
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003984964998380747
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014895300409989432
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015272699602064677
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033974399775615893
STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002061180020973552
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021407200256362557
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020405499526532367
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017039800150087103
STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021828199896845035
STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029430000722641125
STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023592899742652662
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044643199726124294
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024489700444974005
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023622100343345664
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011218399959034286
STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023358099861070514
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035592299900599755
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003567100029613357
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023944900749484077
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023346799571299925
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028098400434828363
STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020742199922096916
STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020266600404283963
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002037430058408063
STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=5.840800440637395e-05
STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002167069000279298
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.002404786995612085
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-FOA.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.0003004520003742073s with changed=0 scale=0
[GraphBuilder-FOA.optimize] done with 29 nodes in 0.066
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Comparison optimized and not optimized?¶
The following tool tries to match the nodes and shape inference results of two models. If the models are not too different, the function is able to find the differences. We can use it to see which operators were fused into bigger ones that are only implemented by onnxruntime.
# Execute the unoptimized and the optimized models on the same generated
# inputs, then align their intermediate results so that fused operators
# (e.g. LayerNormalization -> SkipLayerNormalization) can be spotted.
res1, res2, align, dc = compare_onnx_execution(
    onx,
    onx_optimized,
    verbose=1,
    cls=ExtendedReferenceEvaluator,
)
print("------------")
# Render the side-by-side alignment as text and display it.
print(dc.to_str(res1, res2, align))
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32 2:256x256 AOCQ b_ | INITIA float32 1:1 ?AAA in
002 - | INITIA float32 2:256x256 AOCQ b_ |
003 - | INITIA int64 1:1 BAAA in |
004 - | INITIA int64 1:1 AAAA in |
005 - | INITIA int64 1:1 EAAA in |
006 ~ | INITIA float32 1:1 ?AAA in | INITIA float32 2:16x16 AZZA p_
007 ~ | INITIA float32 2:16x16 AZZA p_ | INITIA float32 2:16x16 CZZA p_
008 ~ | INITIA float32 2:16x16 CZZA p_ | INITIA float32 2:16x16 CACZ p_
009 ~ | INITIA float32 2:16x16 CACZ p_ | INITIA float32 2:30x30 KGSP sl
010 = | INITIA float32 1:1 AAAA in | INITIA float32 1:1 AAAA in
011 ~ | INITIA float32 1:1 AAAA in | INITIA float32 2:16x16 AZAA p_
012 ~ | INITIA float32 2:16x16 AZAA p_ | INITIA float32 2:16x16 AAAB p_
013 ~ | INITIA float32 2:16x16 AAAB p_ | INITIA float32 2:16x16 AABA p_
014 ~ | INITIA float32 2:16x16 AABA p_ | INITIA float32 2:30x30 KGSP sl
015 = | INITIA float32 2:32x16 ZZAA p_ | INITIA float32 2:32x16 ZZAA p_
016 = | INITIA float32 2:16x128 BCAG p_ | INITIA float32 2:16x128 BCAG p_
017 = | INITIA float32 2:128x16 AZAB p_ | INITIA float32 2:128x16 AZAB p_
018 = | INITIA float32 1:16 EEEE in | INITIA float32 1:16 EEEE in
019 = | INITIA float32 1:16 AAAA in | INITIA float32 1:16 AAAA in
020 = | INITIA float32 2:1024x16 HGMR em | INITIA float32 2:1024x16 HGMR em
021 = | INITIA float32 2:1024x16 DYWT em | INITIA float32 2:1024x16 DYWT em
022 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
023 = | INITIA float32 1:128 ABAA de | INITIA float32 1:128 ABAA de
024 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
025 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
026 - | RESULT int64 1:2 ABAA Concat Sl |
027 - | RESULT int64 1:2 EEAA Concat Sl |
028 - | RESULT int64 1:2 AAAA Concat Sl |
029 = | RESULT float32 3:1x30x16 JYWE Gather em | RESULT float32 3:1x30x16 JYWE Gather em
030 = | RESULT float32 3:1x30x16 OVZX Gather em | RESULT float32 3:1x30x16 OVZX Gather em
031 ~ | RESULT float32 3:1x30x16 XTVB Add ad | RESULT float32 3:1x30x16 BZYC SkipLayerNormal _o
032 ~ | RESULT float32 3:1x30x16 BZYC LayerNormalizat _o | RESULT float32 3:1x30x1 AAAB SkipLayerNormal un
033 ~ | RESULT float32 3:1x30x16 AFYA MatMul li | RESULT float32 3:1x30x1 GFFE SkipLayerNormal un
034 ~ | RESULT float32 3:1x30x16 GESE MatMul li | RESULT float32 3:1x30x16 XTVB SkipLayerNormal ad
035 ~ | RESULT float32 3:1x30x16 HADC MatMul li | RESULT float32 3:1x30x16 AFYA MatMul li
036 ~ | RESULT float32 3:1x16x30 CCFX Transpose tr | RESULT float32 3:1x30x16 GESE MatMul li
037 ~ | RESULT float32 3:1x30x30 OXVD MatMul ma | RESULT float32 3:1x30x30 DAFA FusedMatMul _o
038 ~ | RESULT float32 3:1x30x30 DAFA Mul _o | RESULT float32 3:1x30x16 HADC MatMul li
039 - | RESULT float32 2:30x30 KGSP Slice sl |
040 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
041 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
042 = | RESULT float32 3:1x30x30 IHHH Softmax so | RESULT float32 3:1x30x30 IHHH Softmax so
043 = | RESULT float32 3:1x30x16 LEED MatMul ma | RESULT float32 3:1x30x16 LEED MatMul ma
044 = | RESULT float32 3:1x30x16 AFAY MatMul li | RESULT float32 3:1x30x16 AFAY MatMul li
045 = | RESULT float32 3:1x30x16 AZVR MatMul li | RESULT float32 3:1x30x16 AZVR MatMul li
046 ~ | RESULT float32 3:1x30x16 WGBU MatMul li | RESULT float32 3:1x30x30 YWAA FusedMatMul _o
047 ~ | RESULT float32 3:1x16x30 YPVD Transpose tr | RESULT float32 3:1x30x16 WGBU MatMul li
048 ~ | RESULT float32 3:1x30x30 PHXB MatMul ma | RESULT bool 2:30x30 HLZC Equal eq
049 ~ | RESULT float32 3:1x30x30 YWAA Mul _o | RESULT float32 3:1x30x30 ???? Where ma
050 ~ | RESULT float32 2:30x30 KGSP Slice sl | RESULT float32 3:1x30x30 IGHH Softmax so
051 - | RESULT bool 2:30x30 HLZC Equal eq |
052 ~ | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x16 CZAA MatMul ma
053 ~ | RESULT float32 3:1x30x30 IGHH Softmax so | RESULT float32 3:1x30x32 ODED Concat ca
054 ~ | RESULT float32 3:1x30x16 CZAA MatMul ma | RESULT float32 3:1x30x16 CAAA MatMul _o
055 ~ | RESULT float32 3:1x30x32 ODED Concat ca | RESULT float32 3:1x30x16 HFDE Add li
056 ~ | RESULT float32 3:1x30x16 CAAA MatMul _o | RESULT float32 3:1x30x16 AAYC SkipLayerNormal _o
057 ~ | RESULT float32 3:1x30x16 HFDE Add li | RESULT float32 3:1x30x1 AAAB SkipLayerNormal un
058 ~ | RESULT float32 3:1x30x16 DYZG Add ad | RESULT float32 3:1x30x1 GFFE SkipLayerNormal un
059 ~ | RESULT float32 3:1x30x16 AAYC LayerNormalizat _o | RESULT float32 3:1x30x16 DYZG SkipLayerNormal ad
060 = | RESULT float32 3:1x30x128 PQAP MatMul _o | RESULT float32 3:1x30x128 PQAP MatMul _o
061 = | RESULT float32 3:1x30x128 BBMB Add li | RESULT float32 3:1x30x128 BBMB Add li
062 = | RESULT float32 3:1x30x128 VIUL Relu re | RESULT float32 3:1x30x128 VIUL Relu re
063 = | RESULT float32 3:1x30x16 YXAA MatMul _o | RESULT float32 3:1x30x16 YXAA MatMul _o
064 = | RESULT float32 3:1x30x16 ZXAA Add li | RESULT float32 3:1x30x16 ZXAA Add li
065 = | RESULT float32 3:1x30x16 BVAH Add ou | RESULT float32 3:1x30x16 BVAH Add ou
066 = | OUTPUT float32 3:1x30x16 BVAH ou | OUTPUT float32 3:1x30x16 BVAH ou
The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that’s a topic for another example.
Total running time of the script: (0 minutes 2.205 seconds)
Related examples
to_onnx and padding one dimension to a multiple of a constant
to_onnx and a custom operator registered with a function