Note
Go to the end to download the full example code.
to_onnx and submodules from LLMs¶
Big models are hard to read once converted into onnx. Let's see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.
A simple LLM¶
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
    """Token embedding combined with a learned positional term.

    NOTE(review): ``pe`` is indexed by the token ids themselves, not by
    positions — this matches the original tutorial code.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        # Two lookup tables of identical shape: one for the tokens,
        # one used as the positional encoding.
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        # Element-wise sum of the two lookups for each token id.
        return self.embedding(x) + self.pe(x)
class AttentionBlock(torch.nn.Module):
    """Single-head causal self-attention over a fixed context window."""

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        # Separate bias-free projections for queries, keys and values.
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        # Lower-triangular causal mask kept as a non-trainable buffer.
        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        _B, T, C = x.size()
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scaled dot-product scores (B, T, T), scaled by C ** -0.5.
        scores = q @ k.transpose(-2, -1) * C**-0.5
        # Zero entries of the mask (future positions) become -inf so the
        # softmax assigns them no weight.
        scores = scores.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        weights = torch.nn.functional.softmax(input=scores, dim=-1)
        return weights @ v
class MultiAttentionBlock(torch.nn.Module):
    """Multi-head attention: independent heads concatenated then mixed."""

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        heads = [AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        self.attention = torch.nn.ModuleList(modules=heads)
        # Projects the concatenated head outputs back to embedding_dim.
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        per_head = [head(x) for head in self.attention]
        return self.linear(torch.cat(tensors=per_head, dim=-1))
class FeedForward(torch.nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        # Expand to ff_dim, apply the non-linearity, project back.
        return self.linear_2(self.relu(self.linear_1(x)))
class DecoderLayer(torch.nn.Module):
    """Pre-norm decoder block: attention, then feed-forward, each wrapped
    in a residual connection."""

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        # Attention sub-layer: normalize, attend, add the residual.
        att = self.attention(self.norm_1(x)) + x
        # Feed-forward sub-layer with its own residual connection.
        return self.feed_forward(self.norm_2(att)) + att
class LLM(torch.nn.Module):
    """Minimal language model: an embedding layer followed by a single
    decoder layer."""

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        # Map token ids to embeddings, then run the decoder layer.
        return self.decoder(self.embedding(input_ids))
# Instantiate the model and run it once on a random batch of token ids
# to check the forward pass before exporting.
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.227273941040039, max=3.6770823001861572
First conversion to ONNX¶
The conversion relies on torch.export.export().
which gives:
# Export the model to an ExportedProgram and print its fx graph.
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
graph():
%p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
%p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
%p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
%p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
%p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
%p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
%p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
%p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
%p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
%p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
%p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
%p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
%p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
%p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
%p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
%p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
%p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
%p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
%b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
%b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
%input_ids : [num_users=2] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
%embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias, 1e-05, False), kwargs = {})
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
%linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
%transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
%matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
%eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
%masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
%softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
%matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
%linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
%linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
%linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
%transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
%matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
%slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
%masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
%softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
%matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
%linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
%add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias, 1e-05, False), kwargs = {})
%linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
%linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
return (add_2,)
Then the function to_onnx converts it into ONNX.
# Convert the model into a single flat ONNX graph and display it.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check there is no discrepancy.
output: shape=(1, 30, 16), min=-4.227273941040039, max=3.6770823001861572
max discrepancy=2.384185791015625e-07
Let’s save the ONNX model.
# Persist the flat (inlined) ONNX model to disk.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules¶
Let’s produce an ONNX model with submodules.
Function to_onnx
is calling the function torch.export.unflatten.unflatten()
under the hood. The fx graph looks like the following.
# Unflatten the exported program to recover the module hierarchy,
# then print the (much shorter) top-level fx graph.
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
graph():
%input_ids : [num_users=1] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy but it does not necessarily preserve the signatures
of the initial modules. That was not one of our goals.
The tricky part is that the module called (embedding) is not an instance of Embedding
but an instance of InterpreterModule
and contains the fx nodes contributing to the submodule and coming from the
previous graph.
Now the ONNX graph.
# Export again, this time mapping each submodule to a local ONNX function.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias_cst2init) -> attention
Add(attention, embedding) -> add_1
LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
We check again that there are no new discrepancies.
output: shape=(1, 30, 16), min=-4.227273941040039, max=3.6770823001861572
max discrepancy=2.384185791015625e-07
Let’s save the ONNX model.
# Persist the ONNX model that keeps submodules as local functions.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually.

Inlining¶
The ONNX graph can still be inlined after this.
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias_cst2init) -> attention
Add(attention, embedding) -> add_1
LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations¶
The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.
# Export the model a second time, now with optimizations enabled:
# the "default" pattern set plus onnxruntime-specific fusions,
# constant folding to pre-compute static subgraphs, and verbose=2
# so every optimization step is logged.
optimization_options = OptimizationOptions(
    patterns="default+onnxruntime",
    constant_folding=True,
    verbose=2,
)
onx_optimized = to_onnx(llm, (input_ids,), options=optimization_options)

# Display the optimized graph in a compact textual form.
print(pretty_onnx(onx_optimized))
[GraphBuilder-OFO.optimize] start with 73 nodes
[GraphBuilder-OFO.optimize] #patterns=111
[GraphBuilder-OFO.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-OFO.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-OFO.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-OFO.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-OFO.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-OFO.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-OFO.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-OFO.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-OFO.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-OFO.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-OFO.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-OFO.optimize] start with 53 nodes, 28 initializers, 111 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 1/111 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 2/111 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 3/111 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 4/111 - P0 - CastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 5/111 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 6/111 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 7/111 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 8/111 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 9/111 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 10/111 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 11/111 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 12/111 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 13/111 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 14/111 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 15/111 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 16/111 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 17/111 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 18/111 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 19/111 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 20/111 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 21/111 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 22/111 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 23/111 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 24/111 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 25/111 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 26/111 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 27/111 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 28/111 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 29/111 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 30/111 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 31/111 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 32/111 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 33/111 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 34/111 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 35/111 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 36/111 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 37/111 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 38/111 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 39/111 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 40/111 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 41/111 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 42/111 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 43/111 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 44/111 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 45/111 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 46/111 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 47/111 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 48/111 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 49/111 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 50/111 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 51/111 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 52/111 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 53/111 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 54/111 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 55/111 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 56/111 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 57/111 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 58/111 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 59/111 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 60/111 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 61/111 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 62/111 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 63/111 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 64/111 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 65/111 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 66/111 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 67/111 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 68/111 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 69/111 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 70/111 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 71/111 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 72/111 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 73/111 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 74/111 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 75/111 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 76/111 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 77/111 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 78/111 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 79/111 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 80/111 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 81/111 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 82/111 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 83/111 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 84/111 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 85/111 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 86/111 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 87/111 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 88/111 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 89/111 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 90/111 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 91/111 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 92/111 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 93/111 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 94/111 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 95/111 - P1 - SwapRangeAddScalarPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 96/111 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 97/111 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 98/111 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 99/111 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 100/111 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 101/111 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 102/111 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 103/111 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 104/111 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 105/111 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 106/111 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 107/111 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 108/111 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 109/111 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 110/111 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 111/111 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-OFO.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-OFO.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-OFO.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.010 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilder-OFO.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-OFO.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-OFO.optimize] increase priority to 1
[GraphBuilderPatternOptimization-OFO.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-OFO.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.007 | max_time=IdentityPattern:0.000
[GraphBuilder-OFO.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-OFO.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-OFO.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-OFO.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-OFO.optimize] increase priority to 2
[GraphBuilderPatternOptimization-OFO.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-OFO.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-OFO.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-OFO.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-OFO.optimize] increase priority to 3
[GraphBuilderPatternOptimization-OFO.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-OFO.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-OFO.optimize] done after 8 iterations with 29 nodes in 0.068
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002796830000306727
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0008231299999579278
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00040775000002213346
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0008430249999946682
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0003837829999611131
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014849640001557418
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00015791499993156322
STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.2842999922213494e-05
STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008780440000464296
STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007604280000350627
STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007499829998778296
STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007617989999744168
STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0011700530001235165
STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.012680772000067009
STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0043271060000051875
STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009838464999916141
STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.008060679999971399
STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.012667783999972926
STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.007314950000022691
STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.006530704999931913
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0004188979997934439
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.000340633000064372
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002591709998114311
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003446879999273733
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005575349999844548
STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003322699999444012
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002226780000000872
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000522155999988172
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00030480999998871994
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019006400020771252
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00027801299995644513
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00032002899979488575
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002825740000389487
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023173900001438597
STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017946199977814103
STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024100400014503975
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00012077899998530484
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00027817299985599675
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001635399999031506
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018079200015108654
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00026579300015328045
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020278800002415664
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002538889998504601
STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00038194100000055187
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031582799988427723
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022269099997629382
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002135419999831356
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002463259999103684
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.0001035910000837248
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0002082470001596448
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0005211210000197752
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010804199996528041
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.274899994376028e-05
STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002508210000087274
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0031837300001598123
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00400782200017602
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=1.4601000088987348e-05
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018471200007752486
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.003320490999840331
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00035135500013439014
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022237500002120214
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003458630000068297
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.07420000228376e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008081290000063746
STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026140299985399906
STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021998199997597112
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043513500008884876
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002459989999579193
STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029559799997969094
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003301319998172403
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002050850000614446
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024473600012697716
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002501779999874998
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018214800002169795
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006007420001878927
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.7806999923996045e-05
STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.8638000003411435e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039604199992027134
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027059399997142464
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039018499978737964
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002764679998108477
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002886350001745086
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020024400009788224
STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0006482430000005479
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0012550730000384647
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019046899990371458
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019230600014452648
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002702659999158641
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005534680001346715
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004929219999212364
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002562129998295859
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00048176499979035725
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003040590003138277
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000511414000015975
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003237330001866212
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000301928000112639
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00036799400015752326
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002717099998790218
STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002654090000078213
STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003216129999827899
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039327800016053516
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003443280002102256
STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002368949999436154
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021866199983833212
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023388499994325684
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000198361999878216
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005661125999949945
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020763200006967963
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002152060001208156
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005040850002160369
STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002940769999213444
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003017619998217924
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029393400006938464
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023319600006743713
STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028877600004761916
STAT match_SwapRangeAddScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020586899995578278
STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000501877000147033
STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004460769998786418
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006232089999684831
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033637699982591585
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003299620001371295
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00013503599996056437
STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00036288399985551223
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000487903000021106
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000557096000193269
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000331902000084483
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031846400008817
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004097160001492739
STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000298998999937794
STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031692700019902986
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028527600022698607
STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=6.550699993113085e-05
STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002454932999967241
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0023611650001384987
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-OFO.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00024625599996852543s with changed=0 scale=0
[GraphBuilder-OFO.optimize] done with 29 nodes in 0.081
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows a kernel FusedMatMul[com.microsoft] which implements a kernel equivalent to Gemm
but works for any tensors, not only 2D ones.
How does it work on the model which exports the modules as local functions?
The optimizer optimizes every local function independently.
We reduce the verbosity…
# Export the LLM again, this time keeping every submodule as a local ONNX
# function (export_modules_as_functions=True), while still applying the
# default+onnxruntime optimization patterns with constant folding.
# NOTE(review): `llm` and `input_ids` are defined earlier in the example,
# outside this excerpt.
onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias) -> attention
SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
It seems to be working as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
Optimizations for CUDA¶
The optimizer may behave differently knowing the model is running on CUDA. It may use different kernels and do different optimizations if needed. That may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and is not able to decide which provider would be more useful for some kernels. This could even be decided at runtime.
# Same export, but the optimizer is told the model will run on CUDA
# (processor="CUDA"); verbose=2 makes it print the optimization trace
# shown below.  NOTE(review): `llm` and `input_ids` come from earlier
# in the example, outside this excerpt.
onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-BIM.optimize] start with 73 nodes
[GraphBuilder-BIM.optimize] #patterns=111
[GraphBuilder-BIM.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-BIM.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-BIM.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-BIM.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-BIM.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-BIM.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-BIM.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-BIM.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-BIM.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-BIM.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-BIM.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-BIM.optimize] start with 53 nodes, 28 initializers, 111 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 1/111 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 2/111 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 3/111 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 4/111 - P0 - CastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 5/111 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 6/111 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 7/111 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 8/111 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 9/111 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 10/111 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 11/111 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 12/111 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 13/111 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 14/111 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 15/111 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 16/111 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 17/111 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 18/111 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 19/111 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 20/111 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 21/111 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 22/111 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 23/111 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 24/111 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 25/111 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 26/111 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 27/111 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 28/111 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 29/111 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 30/111 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 31/111 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 32/111 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 33/111 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 34/111 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 35/111 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 36/111 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 37/111 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 38/111 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 39/111 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 40/111 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 41/111 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 42/111 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 43/111 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 44/111 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 45/111 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 46/111 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 47/111 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 48/111 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 49/111 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 50/111 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 51/111 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 52/111 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 53/111 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 54/111 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 55/111 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 56/111 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 57/111 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 58/111 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 59/111 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 60/111 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 61/111 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 62/111 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 63/111 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 64/111 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 65/111 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 66/111 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 67/111 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 68/111 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 69/111 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 70/111 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 71/111 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 72/111 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 73/111 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 74/111 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 75/111 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 76/111 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 77/111 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 78/111 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 79/111 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 80/111 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 81/111 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 82/111 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 83/111 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 84/111 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 85/111 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 86/111 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 87/111 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 88/111 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 89/111 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 90/111 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 91/111 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 92/111 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 93/111 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 94/111 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 95/111 - P1 - SwapRangeAddScalarPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 96/111 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 97/111 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 98/111 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 99/111 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 100/111 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 101/111 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 102/111 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 103/111 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 104/111 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 105/111 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 106/111 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 107/111 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 108/111 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 109/111 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 110/111 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 111/111 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-BIM.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-BIM.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-BIM.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.014 | max_time=GeluOrtPattern:0.003
[GraphBuilder-BIM.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-BIM.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-BIM.optimize] increase priority to 1
[GraphBuilderPatternOptimization-BIM.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-BIM.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.011 | max_time=IdentityPattern:0.001
[GraphBuilder-BIM.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-BIM.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-BIM.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.008 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization-BIM.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-BIM.optimize] increase priority to 2
[GraphBuilderPatternOptimization-BIM.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-BIM.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilder-BIM.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-BIM.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-BIM.optimize] increase priority to 3
[GraphBuilderPatternOptimization-BIM.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-BIM.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-BIM.optimize] done after 8 iterations with 29 nodes in 0.083
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0003927450001128818
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0008372890000600819
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0005206710000038584
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0008976920000804967
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0005576270000346994
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0024291549998451956
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00019574699990698718
STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.4463999832514673e-05
STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0010345870000492141
STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008410969999204099
STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0010721369999373564
STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0012714620003180244
STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0013911949999965145
STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.017398427999978594
STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.005836232000092423
STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.014595395000014832
STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.010892864999959784
STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.013563587000021471
STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00798747600003935
STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.006097939000028418
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005423460000884006
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0003526250000049913
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021156300010716222
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00048122600003353
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0010116020000623394
STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000383157000214851
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00030494800000724354
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0006804939999938142
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00040012800013755623
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000268471999902431
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005747600000631792
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004547399997818502
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00034406099996431294
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00027270700002191006
STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002884139998968749
STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023516000010204152
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.0001008420001653576
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00035215200000493496
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021424800013392087
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00024107000012918434
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00034454600006483815
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023592499996993865
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019912000004751462
STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004915349996963414
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002977619999455783
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028267300012885244
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021592700011296984
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021003400013341889
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.476300001802883e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00019111500000690285
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.000505160999864529
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.24880000259509e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.090999997672043e-05
STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003279179999253756
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004492137999932311
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005713610999919183
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=2.005099997859361e-05
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001923699999224482
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0046795250000286615
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0004254409998338815
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022064900008444965
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004560923000212824
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.416199989369488e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009771130002036443
STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021876400001019647
STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023216500028411247
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047343099981844716
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028930899998158566
STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002252620001854666
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009103800001639684
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021284299987200939
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019017100021301303
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030792200016094284
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021390700021584053
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007301869998173061
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.6609999963511655e-05
STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.7394999960961286e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004362900000387526
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029923699992195907
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000371487000165871
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028750000012678356
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005287090002639161
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017884500005038717
STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005873640001254898
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0011766749997832449
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019239600021592196
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019355699998868658
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029006000011122524
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005453239999724246
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005247369999779039
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002285820000906824
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005957569999281986
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032299300005433906
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006123439998191316
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032893900026920164
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028429400015284045
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00037383399978807574
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003264350001472849
STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002826959997719314
STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002345729999433388
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004291999998713436
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002579660001629236
STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024846100006925553
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023452700008874672
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018373000000337925
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001962640000101601
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00596758300014244
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019786300003943325
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024092799992558867
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007429560000673519
STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00037768300001062016
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00034901799995168403
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003510870001264266
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002302809999719102
STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003316210001003128
STAT match_SwapRangeAddScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019076599994605203
STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000513451999836434
STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004026749999184176
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006422130001055848
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035893599999781145
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003505620001078569
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00023595400000431255
STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003478300000097079
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000461320000113119
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00045558099998288526
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00041218200021830853
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003610889998526545
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004339389998904153
STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00034320500003559573
STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033947700012504356
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004482400000824782
STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=7.52230002944998e-05
STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.003287892000003012
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0033581790000880574
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-BIM.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00022337699999752658s with changed=0 scale=0
[GraphBuilder-BIM.optimize] done with 29 nodes in 0.095
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Comparing the optimized and the non-optimized model¶
The following tool tries to match the nodes and the shape inference results from two models. If the models are not too different, the function is able to find the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.
# Run both models (original `onx` and `onx_optimized`) on the same inputs and
# align their intermediate results node by node; `verbose=1` prints progress.
# ExtendedReferenceEvaluator is used as the runtime — the comparison below
# shows it also executes the com.microsoft fused operators
# (FusedMatMul, SkipLayerNormalization) present in the optimized model.
res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
# Render the aligned results side by side: '=' rows match, '~' rows differ,
# '-' rows exist only in the first model (they were fused or removed).
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32 2:256x256 AOCQ b_ | INITIA float32 1:1 ?AAA in
002 - | INITIA float32 2:256x256 AOCQ b_ |
003 - | INITIA int64 1:1 BAAA in |
004 - | INITIA int64 1:1 AAAA in |
005 - | INITIA int64 1:1 EAAA in |
006 ~ | INITIA float32 1:1 ?AAA in | INITIA float32 2:16x16 ZAAA p_
007 ~ | INITIA float32 2:16x16 ZAAA p_ | INITIA float32 2:16x16 AABA p_
008 ~ | INITIA float32 2:16x16 AABA p_ | INITIA float32 2:16x16 AAAZ p_
009 ~ | INITIA float32 2:16x16 AAAZ p_ | INITIA float32 2:30x30 KGSP sl
010 = | INITIA float32 1:1 AAAA in | INITIA float32 1:1 AAAA in
011 ~ | INITIA float32 1:1 AAAA in | INITIA float32 2:16x16 AAAZ p_
012 ~ | INITIA float32 2:16x16 AAAZ p_ | INITIA float32 2:16x16 AYAA p_
013 ~ | INITIA float32 2:16x16 AYAA p_ | INITIA float32 2:16x16 AZAA p_
014 ~ | INITIA float32 2:16x16 AZAA p_ | INITIA float32 2:30x30 KGSP sl
015 = | INITIA float32 2:32x16 BCAA p_ | INITIA float32 2:32x16 BCAA p_
016 = | INITIA float32 2:16x128 AEBU p_ | INITIA float32 2:16x128 AEBU p_
017 = | INITIA float32 2:128x16 AZAY p_ | INITIA float32 2:128x16 AZAY p_
018 = | INITIA float32 1:16 EEEE in | INITIA float32 1:16 EEEE in
019 = | INITIA float32 1:16 AAAA in | INITIA float32 1:16 AAAA in
020 = | INITIA float32 2:1024x16 XLXT em | INITIA float32 2:1024x16 XLXT em
021 = | INITIA float32 2:1024x16 KQDM em | INITIA float32 2:1024x16 KQDM em
022 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
023 = | INITIA float32 1:128 AAAA de | INITIA float32 1:128 AAAA de
024 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
025 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
026 - | RESULT int64 1:2 ABAA Concat Sl |
027 - | RESULT int64 1:2 EEAA Concat Sl |
028 - | RESULT int64 1:2 AAAA Concat Sl |
029 = | RESULT float32 3:1x30x16 FTQA Gather em | RESULT float32 3:1x30x16 FTQA Gather em
030 = | RESULT float32 3:1x30x16 JHDX Gather em | RESULT float32 3:1x30x16 JHDX Gather em
031 ~ | RESULT float32 3:1x30x16 OATX Add ad | RESULT float32 3:1x30x16 AAXD SkipLayerNormal _o
032 ~ | RESULT float32 3:1x30x16 AAXD LayerNormalizat _o | RESULT float32 3:1x30x1 BZBA SkipLayerNormal un
033 ~ | RESULT float32 3:1x30x16 WGBZ MatMul li | RESULT float32 3:1x30x1 GFFE SkipLayerNormal un
034 ~ | RESULT float32 3:1x30x16 GEVQ MatMul li | RESULT float32 3:1x30x16 OATX SkipLayerNormal ad
035 ~ | RESULT float32 3:1x30x16 WYCX MatMul li | RESULT float32 3:1x30x16 WGBZ MatMul li
036 ~ | RESULT float32 3:1x16x30 XSDD Transpose tr | RESULT float32 3:1x30x16 GEVQ MatMul li
037 ~ | RESULT float32 3:1x30x30 UAMZ MatMul ma | RESULT float32 3:1x30x30 ZAXA FusedMatMul _o
038 ~ | RESULT float32 3:1x30x30 ZAXA Mul _o | RESULT float32 3:1x30x16 WYCX MatMul li
039 - | RESULT float32 2:30x30 KGSP Slice sl |
040 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
041 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
042 = | RESULT float32 3:1x30x30 HGHH Softmax so | RESULT float32 3:1x30x30 HGHH Softmax so
043 = | RESULT float32 3:1x30x16 ZXXZ MatMul ma | RESULT float32 3:1x30x16 ZXXZ MatMul ma
044 = | RESULT float32 3:1x30x16 YFZA MatMul li | RESULT float32 3:1x30x16 YFZA MatMul li
045 = | RESULT float32 3:1x30x16 AACY MatMul li | RESULT float32 3:1x30x16 AACY MatMul li
046 ~ | RESULT float32 3:1x30x16 XFDG MatMul li | RESULT float32 3:1x30x30 YABZ FusedMatMul _o
047 ~ | RESULT float32 3:1x16x30 XGCV Transpose tr | RESULT float32 3:1x30x16 XFDG MatMul li
048 ~ | RESULT float32 3:1x30x30 PDHV MatMul ma | RESULT bool 2:30x30 HLZC Equal eq
049 ~ | RESULT float32 3:1x30x30 YABZ Mul _o | RESULT float32 3:1x30x30 ???? Where ma
050 ~ | RESULT float32 2:30x30 KGSP Slice sl | RESULT float32 3:1x30x30 IHHH Softmax so
051 - | RESULT bool 2:30x30 HLZC Equal eq |
052 ~ | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x16 TZEE MatMul ma
053 ~ | RESULT float32 3:1x30x30 IHHH Softmax so | RESULT float32 3:1x30x32 RWBC Concat ca
054 ~ | RESULT float32 3:1x30x16 TZEE MatMul ma | RESULT float32 3:1x30x16 DZAA MatMul _o
055 ~ | RESULT float32 3:1x30x32 RWBC Concat ca | RESULT float32 3:1x30x16 AWWX Add li
056 ~ | RESULT float32 3:1x30x16 DZAA MatMul _o | RESULT float32 3:1x30x16 ZBXD SkipLayerNormal _o
057 ~ | RESULT float32 3:1x30x16 AWWX Add li | RESULT float32 3:1x30x1 BYBA SkipLayerNormal un
058 ~ | RESULT float32 3:1x30x16 OVPU Add ad | RESULT float32 3:1x30x1 GFFE SkipLayerNormal un
059 ~ | RESULT float32 3:1x30x16 ZBXD LayerNormalizat _o | RESULT float32 3:1x30x16 OVPU SkipLayerNormal ad
060 = | RESULT float32 3:1x30x128 LIFE MatMul _o | RESULT float32 3:1x30x128 LIFE MatMul _o
061 = | RESULT float32 3:1x30x128 GFBA Add li | RESULT float32 3:1x30x128 GFBA Add li
062 = | RESULT float32 3:1x30x128 WBAP Relu re | RESULT float32 3:1x30x128 WBAP Relu re
063 = | RESULT float32 3:1x30x16 MTTU MatMul _o | RESULT float32 3:1x30x16 MTTU MatMul _o
064 = | RESULT float32 3:1x30x16 NUUU Add li | RESULT float32 3:1x30x16 NUUU Add li
065 = | RESULT float32 3:1x30x16 APIN Add ou | RESULT float32 3:1x30x16 APIN Add ou
066 = | OUTPUT float32 3:1x30x16 APIN ou | OUTPUT float32 3:1x30x16 APIN ou
The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that’s a topic for another example.
Total running time of the script: (0 minutes 2.687 seconds)
Related examples
to_onnx and padding one dimension to a multiple of a constant
to_onnx and a custom operator registered with a function