to_onnx and submodules from LLMs¶
Big models are hard to read once converted into ONNX. Let's see how to improve their readability. The code is inspired by LLM from scratch with PyTorch.
A simple LLM¶
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        _B, T, C = x.size()
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)
        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x
        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention
        return ff


class LLM(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y

llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.333163261413574, max=4.399672031402588
First conversion to ONNX¶
The conversion relies on torch.export.export(), which gives:
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
graph():
%p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
%p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
%p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
%p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
%p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
%p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
%p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
%p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
%p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
%p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
%p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
%p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
%p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
%p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
%p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
%p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
%p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
%p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
%b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
%b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
%input_ids : [num_users=2] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
%embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
%linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
%transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
%matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
%eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
%masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
%softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
%matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
%linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
%linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
%linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
%transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
%matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
%slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
%masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
%softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
%matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
%linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
%add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
%linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
%linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
return (add_2,)
Then the function to_onnx converts it into ONNX.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check there is no discrepancy.
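The check itself can be done as follows (a minimal sketch, assuming the exported model keeps the input name input_ids shown in the graph above; sess and got are illustrative names): run the model with onnxruntime and compare with the PyTorch output y computed earlier.
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, {"input_ids": input_ids.numpy()})[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={abs(y.detach().numpy() - got).max()}")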
output: shape=(1, 30, 16), min=-4.333163261413574, max=4.399672031402588
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules¶
Let’s produce an ONNX model with submodules.
The function to_onnx calls torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
graph():
%input_ids : [num_users=1] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy, but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule, which contains the fx nodes contributing to the submodule and coming from the previous graph.
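A quick way to see this is to look at the submodule directly (a sketch; attribute access by submodule name on the unflattened module is assumed here):
print(type(unflatten_ep.embedding))  # an InterpreterModule, not the original Embedding
print(unflatten_ep.embedding.graph)  # the fx nodes contributing to this submodule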
Now the ONNX graph.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias_cst2init) -> attention
Add(attention, embedding) -> add_1
LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
We check again that there are no new discrepancies.
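The same kind of check can be run against onx_module, this time with the ExtendedReferenceEvaluator imported above, which can evaluate the local functions (a sketch; ref and got_module are illustrative names):
ref = ExtendedReferenceEvaluator(onx_module)
got_module = ref.run(None, {"input_ids": input_ids.numpy()})[0]
print(f"output: shape={got_module.shape}, min={got_module.min()}, max={got_module.max()}")
print(f"max discrepancy={abs(y.detach().numpy() - got_module).max()}")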
output: shape=(1, 30, 16), min=-4.333163261413574, max=4.399672031402588
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually.
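The gallery shows the rendered graph at this point. A minimal sketch of how such a drawing can be produced with plot_dot imported above, assuming it accepts a ModelProto:
plot_dot(onx_module)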

Inlining¶
The ONNX graph can still be inlined after this.
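The inlining step is a one-liner with onnx.inliner, imported above (a sketch; onx_inlined is an illustrative name):
onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))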
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,) -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias_cst2init) -> attention
Add(attention, embedding) -> add_1
LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations¶
The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That's why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let's see how it goes.
onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True, verbose=2),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-HLO.optimize] start with 73 nodes
[GraphBuilder-HLO.optimize] #patterns=110
[GraphBuilder-HLO.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-HLO.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-HLO.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-HLO.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-HLO.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-HLO.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-HLO.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-HLO.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-HLO.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-HLO.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-HLO.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-HLO.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 46/110 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 47/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 48/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 49/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 50/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 51/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 52/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 53/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 54/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 55/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 56/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 57/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 58/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 59/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 60/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 61/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 62/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 63/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 64/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 65/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 66/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 67/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 68/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 69/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 70/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 71/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 72/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 73/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 74/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 75/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 76/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 77/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 78/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 79/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 80/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 81/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 82/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 83/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 84/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 85/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 86/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 87/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 88/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 89/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 90/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 91/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 92/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 93/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 94/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 95/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 96/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 97/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 98/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 99/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 100/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 101/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 102/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 103/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 104/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 105/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 106/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 107/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 108/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 109/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-HLO.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-HLO.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-HLO.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilder-HLO.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-HLO.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-HLO.optimize] increase priority to 1
[GraphBuilderPatternOptimization-HLO.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-HLO.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-HLO.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-HLO.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-HLO.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.005 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization-HLO.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-HLO.optimize] increase priority to 2
[GraphBuilderPatternOptimization-HLO.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-HLO.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-HLO.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-HLO.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-HLO.optimize] increase priority to 3
[GraphBuilderPatternOptimization-HLO.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-HLO.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-HLO.optimize] done after 8 iterations with 29 nodes in 0.054
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0001723969980957918
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.000673380996886408
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00023979999969014898
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004525479998847004
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0003544610008248128
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0013156780041754246
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011354800153640099
STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=9.716000931803137e-06
STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008246120050898753
STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007010170011199079
STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007421599984809291
STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007541799932369031
STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0009021700025186874
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002740159943641629
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023372700525214896
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017684999693301506
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017152099462691694
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004436009985511191
STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023066400171956047
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018826699670171365
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004079410027770791
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002456150032230653
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015757298751850612
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002946699969470501
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002486100020178128
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002692319903871976
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022198700025910512
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018808200547937304
STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015097799769137055
STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001672569997026585
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.977200195658952e-05
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022119199638837017
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001333680011157412
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001511329974164255
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00021671600552508608
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014730300244991668
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001703389971225988
STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003405250026844442
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026911100576398894
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000207890996534843
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018999300664290786
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018627900135470554
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.528699254384264e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0002006879985856358
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004500110007938929
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.68450008763466e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.222200267482549e-05
STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020293600027798675
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00264904199866578
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0029831009996996727
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.497001545038074e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016557799972360954
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0021669380039384123
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003258099968661554
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020132799909333698
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0025238369998987764
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.261599941761233e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007021879937383346
STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019630399765446782
STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016829699961817823
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00041063000026042573
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002414660011709202
STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021530300000449643
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026854799943976104
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001649089972488582
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026000799698522314
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024280999423353933
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017760900300345384
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005246460059424862
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=3.312199987703934e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003510160058795009
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023547700038761832
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003237529963371344
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022316700051305816
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031471100010094233
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015838200124562718
STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005291520028549712
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0009448159980820492
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018283100507687777
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018946099589811638
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022426699797506444
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005329160012479406
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004506580007728189
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019979700664407574
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004055519966641441
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002698799944482744
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004581969988066703
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002892280062951613
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023894499827292748
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003005779981322121
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024243699954240583
STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022982299924478866
STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000219395002204692
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031030700120027177
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00019556899496819824
STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021693200324079953
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015458200141438283
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016941800276981667
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017050000315066427
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0037985969938745257
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016551199951209128
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016104599490063265
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003632650004874449
STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024043599842116237
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000257101004535798
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002237959997728467
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018160500258090906
STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021849099721293896
STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003226369990443345
STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00039116699917940423
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006148589964141138
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032628500048303977
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003146920025756117
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011871900278492831
STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003131949997623451
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005266170082904864
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047179899411275983
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028350500724627636
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027891799618373625
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003492720024951268
STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025233099586330354
STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002409049920970574
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002463310011080466
STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=4.880599590251222e-05
STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0019299200030218344
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0021111540008860175
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-HLO.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-HLO.optimize] done with 29 nodes in 0.064
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows the kernel FusedMatMul[com.microsoft], which implements an operator equivalent to Gemm
but working on tensors of any rank, not only 2D.
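As a rough sketch (not onnxruntime's actual implementation), the attributes shown above can be read as the following numpy expression: transA and transB transpose the last two axes and alpha scales the batched product. With alpha=0.25 and transB=1 it matches the 1/sqrt(16) scaling used by the attention heads of this model.

import numpy as np

def fused_matmul_sketch(a, b, alpha=1.0, transA=0, transB=0):
    # Hypothetical reading of com.microsoft.FusedMatMul: swap the last two axes
    # when requested, then scale the batched matrix product
    # (transBatchA/transBatchB are assumed to be 0, as in the graph above).
    if transA:
        a = np.swapaxes(a, -2, -1)
    if transB:
        b = np.swapaxes(b, -2, -1)
    return alpha * (a @ b)

query = np.random.rand(1, 30, 16).astype(np.float32)
key = np.random.rand(1, 30, 16).astype(np.float32)
print(fused_matmul_sketch(query, key, alpha=0.25, transB=1).shape)  # (1, 30, 30)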
How does it work on the model which keeps the modules exported as local functions?
The optimizer optimizes every local function independently.
We reduce the verbosity…
onx_module_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, weight::T105) -> _onx_matmul_cat
Add(_onx_matmul_cat, bias) -> attention
SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, weight::T1023) -> _onx_matmul_relu
Add(_onx_matmul_relu, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
It seems to work on this simple case as well, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components. A quick numerical comparison, sketched below, can confirm that both exports still compute the same output.
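A minimal sketch of such a check, assuming onx_optimized is the flat optimized model produced earlier (the variable name is hypothetical here) and input_ids is the example input tensor: both models are run with onnxruntime on the same feeds and the largest absolute difference is printed.

import numpy as np
from onnxruntime import InferenceSession

# Sketch only: `onx_optimized` stands for the flat optimized export built earlier,
# `onx_module_optimized` is the export keeping the modules as local functions.
feeds = {"input_ids": input_ids.numpy()}  # assumes input_ids is a torch tensor

sess_flat = InferenceSession(
    onx_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)
sess_module = InferenceSession(
    onx_module_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)

out_flat = sess_flat.run(None, feeds)[0]
out_module = sess_module.run(None, feeds)[0]
print("max abs diff:", np.abs(out_flat - out_module).max())  # expected close to 0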
Optimizations for CUDA¶
The optimizer may behave differently when it knows the model will run on CUDA. It may use different kernels and apply different optimizations if needed. This may not be the best place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more suitable for a given kernel. This could even be decided at runtime.
onx_cuda_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(
patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-EWM.optimize] start with 73 nodes
[GraphBuilder-EWM.optimize] #patterns=110
[GraphBuilder-EWM.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-EWM.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-EWM.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-EWM.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-EWM.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-EWM.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-EWM.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-EWM.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-EWM.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-EWM.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-EWM.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-EWM.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 46/110 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 47/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 48/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 49/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 50/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 51/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 52/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 53/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 54/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 55/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 56/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 57/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 58/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 59/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 60/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 61/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 62/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 63/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 64/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 65/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 66/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 67/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 68/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 69/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 70/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 71/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 72/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 73/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 74/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 75/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 76/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 77/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 78/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 79/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 80/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 81/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 82/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 83/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 84/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 85/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 86/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 87/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 88/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 89/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 90/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 91/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 92/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 93/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 94/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 95/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 96/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 97/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 98/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 99/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 100/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 101/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 102/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 103/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 104/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 105/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 106/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 107/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 108/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 109/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-EWM.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-EWM.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-EWM.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilder-EWM.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-EWM.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-EWM.optimize] increase priority to 1
[GraphBuilderPatternOptimization-EWM.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-EWM.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-EWM.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-EWM.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-EWM.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-EWM.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-EWM.optimize] increase priority to 2
[GraphBuilderPatternOptimization-EWM.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-EWM.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilder-EWM.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-EWM.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-EWM.optimize] increase priority to 3
[GraphBuilderPatternOptimization-EWM.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-EWM.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-EWM.optimize] done after 8 iterations with 29 nodes in 0.045
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00017572700016899034
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0004853729988099076
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00023937799414852634
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004224229996907525
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00016526099716429599
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0010593710067041684
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010789400039357133
STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=8.25899769552052e-06
STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006228849924809765
STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005360539980756585
STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005590109976765234
STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005760540043411311
STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.00072660599835217
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023676199634792283
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00020293999841669574
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013246700109448284
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002626610003062524
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004166899998381268
STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00020839499484281987
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015502799942623824
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003525500032992568
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002206230019510258
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013575199409388006
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002454509995004628
STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018761200044536963
STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00026909499865723774
STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002003980043809861
STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017381199722876772
STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013258099716040306
STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013269899864098988
STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.581199704669416e-05
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001918070083775092
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011614200775511563
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012782899648300372
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000193311003386043
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012491100278566591
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013604500054498203
STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025412800459889695
STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020083299750695005
STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013776700143353082
STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016289999621221796
STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001309809995291289
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.145099541754462e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0001370050013065338
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00032459900103276595
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.578800144372508e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.085700053721666e-05
STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001712220000626985
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0022345819961628877
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0025460659962845966
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=7.695001841057092e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013169399971957318
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0019215859974792693
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002948989967990201
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001595130015630275
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0023374550000880845
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.9786001656902954e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005574090027948841
STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014472100156126544
STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013496300016413443
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003049280021514278
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017915699936565943
STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017910299720824696
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000239348995819455
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013305900210980326
STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013289100388647057
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017903599655255675
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013284999658935703
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003690820012707263
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.931899896590039e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028197900246595964
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021704999744542874
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025234700297005475
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019502999930409715
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001841830016928725
STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012774599963449873
STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004596960025082808
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007208260030893143
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013939599739387631
STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013823499830323271
STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001885819910967257
STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00042165600825683214
STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034922899794764817
STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016093799786176533
STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033534799149492756
STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021886500690015964
STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003706440002133604
STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023459699514205568
STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020034900080645457
STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000267596999037778
STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019412000619922765
STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019651100228657015
STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017006600683089346
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002447639999445528
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.000168531001691008
STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014786100291530602
STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012568600504891947
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013235599908512086
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016139199942699634
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0032920120065682568
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015577500380459242
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001337530011369381
STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033944000460905954
STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019519499619491398
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020287299776100554
STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019088300177827477
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015230500139296055
STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001900570023281034
STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002686159969016444
STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021872000070288777
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003819379890046548
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019452800552244298
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019614600023487583
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.107699977699667e-05
STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022033500135876238
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035078099608654156
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031080899861990474
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021419699987745844
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024124601259245537
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002268880016345065
STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001956129926838912
STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002117060030286666
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018981700122822076
STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=4.4573997001862153e-05
STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0016913559920794796
STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0026500370004214346
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 21 x 1t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 4 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 3 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-EWM.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-EWM.optimize] done with 29 nodes in 0.052
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Comparison between the optimized and non-optimized models¶
The following tool tries to match the nodes and inferred shapes of two models. If the models are not too different, the function is able to align them and highlight the differences. We can use it to see which operators were fused into bigger ones implemented only by onnxruntime.
res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32 2:256x256 AOCQ b_ | INITIA float32 1:1 ?AAA in
002 - | INITIA float32 2:256x256 AOCQ b_ |
003 - | INITIA int64 1:1 BAAA in |
004 - | INITIA int64 1:1 AAAA in |
005 - | INITIA int64 1:1 EAAA in |
006 ~ | INITIA float32 1:1 ?AAA in | INITIA float32 2:16x16 AAAB p_
007 ~ | INITIA float32 2:16x16 AAAB p_ | INITIA float32 2:16x16 BAAB p_
008 ~ | INITIA float32 2:16x16 BAAB p_ | INITIA float32 2:16x16 AAAA p_
009 ~ | INITIA float32 2:16x16 AAAA p_ | INITIA float32 2:30x30 KGSP sl
010 = | INITIA float32 1:1 AAAA in | INITIA float32 1:1 AAAA in
011 ~ | INITIA float32 1:1 AAAA in | INITIA float32 2:16x16 BBAA p_
012 ~ | INITIA float32 2:16x16 BBAA p_ | INITIA float32 2:16x16 AAZA p_
013 ~ | INITIA float32 2:16x16 AAZA p_ | INITIA float32 2:16x16 ZAAA p_
014 ~ | INITIA float32 2:16x16 ZAAA p_ | INITIA float32 2:30x30 KGSP sl
015 = | INITIA float32 2:32x16 AZBC p_ | INITIA float32 2:32x16 AZBC p_
016 = | INITIA float32 2:16x128 CZEW p_ | INITIA float32 2:16x128 CZEW p_
017 = | INITIA float32 2:128x16 AAZA p_ | INITIA float32 2:128x16 AAZA p_
018 = | INITIA float32 1:16 EEEE in | INITIA float32 1:16 EEEE in
019 = | INITIA float32 1:16 AAAA in | INITIA float32 1:16 AAAA in
020 = | INITIA float32 2:1024x16 PHTW em | INITIA float32 2:1024x16 PHTW em
021 = | INITIA float32 2:1024x16 ILTN em | INITIA float32 2:1024x16 ILTN em
022 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
023 = | INITIA float32 1:128 AAZA de | INITIA float32 1:128 AAZA de
024 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
025 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
026 - | RESULT int64 1:2 ABAA Concat Sl |
027 - | RESULT int64 1:2 EEAA Concat Sl |
028 - | RESULT int64 1:2 AAAA Concat Sl |
029 = | RESULT float32 3:1x30x16 VPRM Gather em | RESULT float32 3:1x30x16 VPRM Gather em
030 = | RESULT float32 3:1x30x16 ILJE Gather em | RESULT float32 3:1x30x16 ILJE Gather em
031 ~ | RESULT float32 3:1x30x16 DAAQ Add ad | RESULT float32 3:1x30x16 ZBXD SkipLayerNormal _o
032 ~ | RESULT float32 3:1x30x16 ZBXD LayerNormalizat _o | RESULT float32 3:1x30x1 ZAZA SkipLayerNormal un
033 ~ | RESULT float32 3:1x30x16 OTTE MatMul li | RESULT float32 3:1x30x1 GFGE SkipLayerNormal un
034 ~ | RESULT float32 3:1x30x16 VDVA MatMul li | RESULT float32 3:1x30x16 DAAQ SkipLayerNormal ad
035 ~ | RESULT float32 3:1x30x16 VVDX MatMul li | RESULT float32 3:1x30x16 OTTE MatMul li
036 ~ | RESULT float32 3:1x16x30 BBOA Transpose tr | RESULT float32 3:1x30x16 VDVA MatMul li
037 ~ | RESULT float32 3:1x30x30 XIBS MatMul ma | RESULT float32 3:1x30x30 AWAS FusedMatMul _o
038 ~ | RESULT float32 3:1x30x30 AWAS Mul _o | RESULT float32 3:1x30x16 VVDX MatMul li
039 - | RESULT float32 2:30x30 KGSP Slice sl |
040 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
041 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
042 = | RESULT float32 3:1x30x30 HHHH Softmax so | RESULT float32 3:1x30x30 HHHH Softmax so
043 = | RESULT float32 3:1x30x16 UXWX MatMul ma | RESULT float32 3:1x30x16 UXWX MatMul ma
044 = | RESULT float32 3:1x30x16 GXXA MatMul li | RESULT float32 3:1x30x16 GXXA MatMul li
045 = | RESULT float32 3:1x30x16 IACC MatMul li | RESULT float32 3:1x30x16 IACC MatMul li
046 ~ | RESULT float32 3:1x30x16 XZZX MatMul li | RESULT float32 3:1x30x30 HUEC FusedMatMul _o
047 ~ | RESULT float32 3:1x16x30 BDYK Transpose tr | RESULT float32 3:1x30x16 XZZX MatMul li
048 ~ | RESULT float32 3:1x30x30 DZSJ MatMul ma | RESULT bool 2:30x30 HLZC Equal eq
049 ~ | RESULT float32 3:1x30x30 HUEC Mul _o | RESULT float32 3:1x30x30 ???? Where ma
050 ~ | RESULT float32 2:30x30 KGSP Slice sl | RESULT float32 3:1x30x30 IHHH Softmax so
051 - | RESULT bool 2:30x30 HLZC Equal eq |
052 ~ | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x16 AZZX MatMul ma
053 ~ | RESULT float32 3:1x30x30 IHHH Softmax so | RESULT float32 3:1x30x32 VVVU Concat ca
054 ~ | RESULT float32 3:1x30x16 AZZX MatMul ma | RESULT float32 3:1x30x16 ZWWX MatMul _o
055 ~ | RESULT float32 3:1x30x32 VVVU Concat ca | RESULT float32 3:1x30x16 AWXX Add li
056 ~ | RESULT float32 3:1x30x16 ZWWX MatMul _o | RESULT float32 3:1x30x16 ZBXD SkipLayerNormal _o
057 ~ | RESULT float32 3:1x30x16 AWXX Add li | RESULT float32 3:1x30x1 ZAYA SkipLayerNormal un
058 ~ | RESULT float32 3:1x30x16 DXWN Add ad | RESULT float32 3:1x30x1 GFGE SkipLayerNormal un
059 ~ | RESULT float32 3:1x30x16 ZBXD LayerNormalizat _o | RESULT float32 3:1x30x16 DXWN SkipLayerNormal ad
060 = | RESULT float32 3:1x30x128 GCKF MatMul _o | RESULT float32 3:1x30x128 GCKF MatMul _o
061 = | RESULT float32 3:1x30x128 BWFY Add li | RESULT float32 3:1x30x128 BWFY Add li
062 = | RESULT float32 3:1x30x128 AXWJ Relu re | RESULT float32 3:1x30x128 AXWJ Relu re
063 = | RESULT float32 3:1x30x16 BCAY MatMul _o | RESULT float32 3:1x30x16 BCAY MatMul _o
064 = | RESULT float32 3:1x30x16 ZAYV Add li | RESULT float32 3:1x30x16 ZAYV Add li
065 = | RESULT float32 3:1x30x16 BWTH Add ou | RESULT float32 3:1x30x16 BWTH Add ou
066 = | OUTPUT float32 3:1x30x16 BWTH ou | OUTPUT float32 3:1x30x16 BWTH ou
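In this report, rows marked = are present and identical in both models, ~ rows are aligned but differ, and - rows only exist in the first (non-optimized) model. It shows the Transpose/MatMul/Mul chains replaced by FusedMatMul and the Add/LayerNormalization pairs replaced by SkipLayerNormalization, both contributed operators from the com.microsoft domain. As a sanity check, both models can also be executed with onnxruntime and their outputs compared with max_diff. This is only a sketch: it assumes onx and onx_optimized are still in scope and that the graph input is named input_ids, as shown in the dumps above.
import numpy as np

# Random token ids with the shape expected by the model (1 x 30, vocabulary of 1024).
feeds = {"input_ids": np.random.randint(0, 1024, size=(1, 30)).astype(np.int64)}

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
sess_opt = InferenceSession(
    onx_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)

# max_diff reports the largest discrepancy between the two sets of outputs.
print(max_diff(sess.run(None, feeds), sess_opt.run(None, feeds)))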
The conversion should handle dynamic shapes as well, since the input sequence can be of any length, but that is a topic for another example.
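Even though the full treatment is left to another example, a rough, untested sketch is shown below. It assumes to_onnx forwards a torch.export-style dynamic_shapes specification and that the model instance and example input are named llm and input_ids; both names are hypothetical here, and the dictionary key must match the name of the forward argument.
# Hypothetical sketch: export with a dynamic sequence length.
seq_length = torch.export.Dim("seq_length")  # must stay within the context size used for the causal mask
onx_dynamic = to_onnx(
    llm,
    (input_ids,),
    dynamic_shapes={"input_ids": {1: seq_length}},
)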
Total running time of the script: (0 minutes 2.434 seconds)
Related examples
to_onnx and padding one dimension to a multiple of a constant
to_onnx and a custom operator registered with a function