to_onnx and submodules from LLMs
Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.
A simple LLM
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        B, T, C = x.size()
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)
        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x
        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention
        return ff


class LLM(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-3.806222438812256, max=4.5326056480407715
First conversion to ONNX
The conversion relies on torch.export.export(), which gives:
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
graph():
%p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
%p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
%p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
%p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
%p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
%p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
%p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
%p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
%p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
%p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
%p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
%p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
%p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
%p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
%p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
%p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
%p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
%p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
%b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
%b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
%input_ids : [num_users=2] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
%embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
%linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
%transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
%matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, None, 30), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, 30), kwargs = {})
%eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
%masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
%softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
%matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
%linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
%linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
%linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
%transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
%matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
%slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, None, 30), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, None, 30), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
%masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
%softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
%matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
%linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
%add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
%linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
%linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
return (add_2,)
Then the function to_onnx converts it into ONNX.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='_reshape_init1_s_0' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='_reshape_init1_s_02' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='SliceSlicePattern_init7_s1_0_start' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilder.constant_folding.from/fold(init7_s1_0)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilder.constant_folding.from/fold(init7_s1_30)##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_0_start2' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilder.constant_folding.from/fold(init7_s1_0)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end2' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilder.constant_folding.from/fold(init7_s1_30)##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis2' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add00
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, _reshape_init1_s_20) -> eq
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Mul(matmul_2, _reshape_init1_s_02) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_4
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, _onx_transpose_p_decoder_attention_linear_weight0) -> _onx_matmul_cat0
Add(_onx_matmul_cat0, decoder.attention.linear.bias) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_100
MatMul(_onx_div_sub_add_100, _onx_transpose_p_decoder_feed_forward_linear_1_weight0) -> _onx_matmul_layer_norm_10
Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, _onx_transpose_p_decoder_feed_forward_linear_2_weight0) -> _onx_matmul_relu0
Add(_onx_matmul_relu0, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check there is no discrepancy.
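The verification code is not shown above; here is a minimal sketch of what it may look like, assuming ExtendedReferenceEvaluator (imported earlier) accepts the model directly and max_diff returns a dictionary with an "abs" entry:

sess = ExtendedReferenceEvaluator(onx)
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]  # run the ONNX model on the same input
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
# compare the torch output with the onnx output; the 'abs' key is assumed
print(f"max discrepancy={max_diff(y, got)['abs']}")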
output: shape=(1, 30, 16), min=-3.806222438812256, max=4.5326056480407715
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules
Let’s produce an ONNX model with submodules. The function to_onnx calls torch.export.unflatten.unflatten() under the hood. The FX graph looks like the following.
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
graph():
%input_ids : [num_users=1] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy, but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals.
The tricky part is that the module called embedding is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule, taken from the previous graph.
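A small sketch to illustrate it, assuming get_submodule behaves on the unflattened module as on any torch.nn.Module:

print(type(unflatten_ep.get_submodule("embedding")))
# expected to show something like <class 'torch.export.unflatten.InterpreterModule'>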
Now the ONNX graph.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='_onx_transpose_weight05' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight05)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='_onx_transpose_weight06' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(_onx_transpose_weight06)
init: name='_onx_transpose_weight023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, _onx_transpose_weight0) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> _reshape_init1_s_0
Constant(value=[0.0]) -> _reshape_init1_s_20
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, _reshape_init1_s_20) -> eq
MatMul(norm_1, _onx_transpose_weight02) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, _onx_transpose_weight03) -> value
MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> _reshape_init1_s_02
Constant(value=[0.0]) -> _reshape_init1_s_202
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
Equal(slice_22, _reshape_init1_s_202) -> eq2
MatMul(norm_1, _onx_transpose_weight04) -> query2
MatMul(norm_1, _onx_transpose_weight022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, _reshape_init1_s_02) -> _onx_mul_matmul02
Where(eq2, init1_s1_2, _onx_mul_matmul02) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, _onx_transpose_weight032) -> value2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, _onx_transpose_weight05) -> _onx_matmul_cat0
Constant(value=[0.0656113...) -> bias
Add(_onx_matmul_cat0, bias) -> attention
Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, _onx_transpose_weight06) -> _onx_matmul_layer_norm_10
Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, _onx_transpose_weight023) -> _onx_matmul_relu0
Constant(value=[-0.025308...) -> bias2
Add(_onx_matmul_relu0, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
We check again that there are no new discrepancies.
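The check itself is again omitted; a minimal sketch, this time assuming onnxruntime's InferenceSession (imported earlier) handles the local functions:

sess = InferenceSession(
    onx_module.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, dict(input_ids=input_ids.numpy()))[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={max_diff(y, got)['abs']}")  # 'abs' key assumed as above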
output: shape=(1, 30, 16), min=-3.806222438812256, max=4.5326056480407715
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually:
(figure: the model with submodules rendered as a graph)
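The rendering can be reproduced with plot_dot; a minimal sketch, assuming plot_dot (imported above) accepts an ONNX model directly:

plot_dot(onx_module)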
Inlining
The ONNX graph can still be inlined after this.
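A minimal sketch of that step, assuming onnx.inliner.inline_local_functions (imported at the top) is all it takes, which gives the output below:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))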
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='_onx_transpose_weight05' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight05)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='_onx_transpose_weight06' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(_onx_transpose_weight06)
init: name='_onx_transpose_weight023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
MatMul(norm_1, _onx_transpose_weight0) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> _reshape_init1_s_0
Constant(value=[0.0]) -> _reshape_init1_s_20
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, _reshape_init1_s_20) -> eq
MatMul(norm_1, _onx_transpose_weight02) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, _onx_transpose_weight03) -> value
MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> _reshape_init1_s_02
Constant(value=[0.0]) -> _reshape_init1_s_202
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
Equal(slice_22, _reshape_init1_s_202) -> eq2
MatMul(norm_1, _onx_transpose_weight04) -> query2
MatMul(norm_1, _onx_transpose_weight022) -> key2
Transpose(key2, perm=[0,2,1]) -> transpose2
MatMul(query2, transpose2) -> matmul2
Mul(matmul2, _reshape_init1_s_02) -> _onx_mul_matmul02
Where(eq2, init1_s1_2, _onx_mul_matmul02) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, _onx_transpose_weight032) -> value2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, _onx_transpose_weight05) -> _onx_matmul_cat0
Constant(value=[0.0656113...) -> bias
Add(_onx_matmul_cat0, bias) -> attention
Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, _onx_transpose_weight06) -> _onx_matmul_layer_norm_10
Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, _onx_transpose_weight023) -> _onx_matmul_relu0
Constant(value=[-0.025308...) -> bias2
Add(_onx_matmul_relu0, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations
The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.
onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2
    ),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-ZWA.optimize] start with 73 nodes
[GraphBuilder-ZWA.optimize] #patterns=66
[GraphBuilder-ZWA.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-ZWA.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-ZWA.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-ZWA.optimize] start with 53 nodes, 28 initializers, 66 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 1/66 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 2/66 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 3/66 - P0 - CastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 4/66 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 5/66 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 6/66 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 7/66 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 8/66 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 9/66 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 10/66 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 11/66 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 12/66 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 13/66 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 14/66 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 15/66 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 16/66 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 17/66 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 18/66 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 19/66 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 20/66 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 21/66 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 22/66 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 23/66 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 24/66 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 25/66 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 26/66 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 27/66 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 28/66 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 29/66 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 30/66 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 31/66 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 32/66 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 33/66 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 34/66 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 35/66 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 36/66 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 37/66 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 38/66 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 39/66 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 40/66 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 41/66 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 42/66 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 43/66 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 44/66 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 45/66 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 46/66 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 47/66 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 48/66 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 49/66 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 50/66 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 51/66 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 52/66 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 53/66 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 54/66 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 55/66 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 56/66 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 57/66 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 58/66 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 59/66 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 60/66 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 61/66 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 62/66 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 63/66 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 64/66 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 65/66 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern 66/66 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-ZWA.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-ZWA.optimize] increase priority to 1
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-ZWA.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-ZWA.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-ZWA.optimize] increase priority to 2
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-ZWA.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-ZWA.optimize] increase priority to 3
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-ZWA.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-ZWA.optimize] done after 8 iterations with 29 nodes in 0.033
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.000173841996002011
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.00043014800030505285
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00024007700267247856
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.00044869499834021553
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00013733400555793196
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014058530068723485
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0001476519973948598
STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0012261480078450404
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.00028686600126093253
STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.561399742262438e-05
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002683900165720843
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00021369899332057685
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013751999358646572
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012267099373275414
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00042915101221296936
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002062719941022806
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00039391099562635645
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002374749892624095
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014003500109538436
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023889500880613923
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00020226900232955813
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012630200217245147
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001336939967586659
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001983599941013381
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012818598770536482
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013789199147140607
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.229200516827404e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00018165100482292473
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00036033200012752786
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.089899397920817e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.9845006035175174e-05
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0017957130112336017
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0026028189895441756
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=6.459995347540826e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001596389920450747
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.00223082799493568
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00026810899726115167
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001514590039732866
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002048662005108781
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.953899380983785e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006744000129401684
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003577850147848949
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002119630080414936
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002014479978242889
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013585301348939538
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018921001174021512
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018153598648495972
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004234040097799152
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.0284998754505068e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030415500077651814
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020345000666566193
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00027579999732552096
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002137810006388463
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019520599016686901
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0006518419977510348
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014285900397226214
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000290568990749307
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00020128698815824464
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001317569985985756
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013952799781691283
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00408375401457306
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001676660103839822
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001409980104654096
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022318900300888345
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019362599414307624
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00042318199848523363
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002227019940619357
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021736099006375298
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.86179948085919e-05
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032966199796646833
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034290101029910147
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002329110211576335
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033013800566550344
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002353209929424338
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002007650036830455
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0007423379938700236
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 29 x 1t
INIT: 1 x 7t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-ZWA.remove_unused] remove_initializer 1:5/30:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 2:6/30:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 3:7/30:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 4:8/30:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 5:9/30:init7_s1_-1:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 6:10/30:init1_s1_:float32[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 7:11/30:init1_s1_2:float32[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 8:16/30:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 9:22/30:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-ZWA.optimize] done with 29 nodes in 0.042
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add00, unused, unused2, add
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, _onx_transpose_p_decoder_attention_linear_weight0) -> _onx_matmul_cat0
Add(_onx_matmul_cat0, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_100, unused3, unused4, add_1
MatMul(_onx_div_sub_add_100, _onx_transpose_p_decoder_feed_forward_linear_1_weight0) -> _onx_matmul_layer_norm_10
Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, _onx_transpose_p_decoder_feed_forward_linear_2_weight0) -> _onx_matmul_relu0
Add(_onx_matmul_relu0, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows the kernel FusedMatMul[com.microsoft], which implements the equivalent of Gemm but works on tensors of any rank, not only 2D ones.
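As a quick check, here is a minimal numpy sketch (not part of the original example) of what this FusedMatMul computes for the attention scores above, with alpha=0.25 and transB=1:

import numpy as np

# alpha * query @ key^T over the last two axes, broadcasting the batch axis;
# Gemm would reject these rank-3 inputs.
query = np.random.rand(1, 30, 16).astype(np.float32)
key = np.random.rand(1, 30, 16).astype(np.float32)
scores = 0.25 * query @ key.transpose(0, 2, 1)
print(scores.shape)  # (1, 30, 30)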
How does it work on the model which keeps the modules exported as local functions? The optimizer optimizes every local function independently. We reduce the verbosity…
onx_module_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_4)
init: name='_onx_transpose_weight05' type=float32 shape=(32, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight05)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='_onx_transpose_weight06' type=float32 shape=(16, 128) -- GraphBuilder.make_local_function/from(_onx_transpose_weight06)
init: name='_onx_transpose_weight023' type=float32 shape=(128, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight023)
init: name='init1_s16_3' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_2' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,) -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_, init1_s16_2, epsilon=0.00) -> norm_1, unused, unused2, embedding
MatMul(norm_1, _onx_transpose_weight0) -> query
MatMul(norm_1, _onx_transpose_weight02) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, _onx_transpose_weight03) -> value
MatMul(softmax, value) -> attention_0
MatMul(norm_1, _onx_transpose_weight04) -> query2
MatMul(norm_1, _onx_transpose_weight022) -> key2
FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul02
MatMul(norm_1, _onx_transpose_weight032) -> value2
Equal(slice_4, _reshape_init1_s_202) -> eq2
Where(eq2, init1_s1_2, _onx_mul_matmul02) -> masked_fill2
Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(softmax2, value2) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
MatMul(cat, _onx_transpose_weight05) -> _onx_matmul_cat0
Add(_onx_matmul_cat0, bias) -> attention
Add(attention, embedding) -> add_1
LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
MatMul(norm_2, _onx_transpose_weight06) -> _onx_matmul_layer_norm_10
Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_1
Relu(linear_1) -> relu
MatMul(relu, _onx_transpose_weight023) -> _onx_matmul_relu0
Add(_onx_matmul_relu0, bias2) -> feed_forward
Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
It seems to work just as well on this simple case, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
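Since the export kept the submodules as local functions, onnx.inliner can flatten them back, which helps when a tool does not support local functions. A minimal sketch reusing inline_local_functions imported at the top of this example:

# Inline every local function back into the main graph and count the nodes.
inlined = inline_local_functions(onx_module_optimized)
print(len(inlined.graph.node), "nodes after inlining")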
Optimizations for CUDA¶
The optimizer may behave differently knowing the model runs on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another on CUDA. The current optimization does not know that and cannot decide which provider would be more suitable for a given kernel. This could even be decided at runtime.
onx_cuda_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(
patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-VNC.optimize] start with 73 nodes
[GraphBuilder-VNC.optimize] #patterns=66
[GraphBuilder-VNC.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-VNC.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-VNC.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-VNC.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-VNC.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-VNC.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-VNC.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-VNC.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-VNC.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-VNC.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-VNC.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-VNC.optimize] start with 53 nodes, 28 initializers, 66 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 1/66 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 2/66 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 3/66 - P0 - CastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 4/66 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 5/66 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 6/66 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 7/66 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 8/66 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 9/66 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 10/66 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 11/66 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 12/66 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 13/66 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 14/66 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 15/66 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 16/66 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 17/66 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 18/66 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 19/66 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 20/66 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 21/66 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 22/66 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 23/66 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 24/66 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 25/66 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 26/66 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 27/66 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 28/66 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 29/66 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 30/66 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 31/66 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 32/66 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 33/66 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 34/66 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 35/66 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 36/66 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 37/66 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 38/66 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 39/66 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 40/66 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 41/66 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 42/66 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 43/66 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 44/66 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 45/66 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 46/66 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 47/66 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 48/66 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 49/66 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 50/66 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 51/66 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 52/66 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 53/66 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 54/66 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 55/66 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 56/66 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 57/66 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 58/66 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 59/66 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 60/66 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 61/66 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 62/66 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 63/66 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 64/66 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 65/66 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern 66/66 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-VNC.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-VNC.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.008 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-VNC.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-VNC.optimize] increase priority to 1
[GraphBuilderPatternOptimization-VNC.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-VNC.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-VNC.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-VNC.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-VNC.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-VNC.optimize] increase priority to 2
[GraphBuilderPatternOptimization-VNC.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-VNC.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-VNC.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-VNC.optimize] increase priority to 3
[GraphBuilderPatternOptimization-VNC.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-VNC.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-VNC.optimize] done after 8 iterations with 29 nodes in 0.038
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00026056999195134267
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0005157869964023121
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00030466200405498967
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0005376979970606044
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00022099500347394496
STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014846519916318357
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00021267700503813103
STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0015150120161706582
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.0003633899978012778
STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.4903994926717132e-05
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0004113059985684231
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00028710500191664323
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001591440086485818
STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002895809957408346
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004520799921010621
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002281540000694804
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004967470013070852
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00029728899971814826
STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015165000513661653
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00026441601221449673
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00024405600561294705
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001320449955528602
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001397379965055734
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002740909985732287
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014128500333754346
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014789899432798848
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.449100328609347e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00014306599769042805
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00035403600486461073
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.892500434536487e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.7320998166687787e-05
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00228474000323331
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003075333996093832
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=7.80199479777366e-06
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001459740087739192
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0026902310055447742
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003888280116370879
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018430400814395398
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0028149299978394993
STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.969700189074501e-05
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006638829945586622
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035725799534702674
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002011279866565019
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024763400870142505
STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001439319967175834
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020257899450371042
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015912499657133594
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004544649927993305
STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.912900188472122e-05
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032406898390036076
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033377300132997334
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002825109986588359
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026791998971020803
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020742399647133425
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007477519975509495
STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014944799477234483
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028798499261029065
STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0001661589994910173
STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013994799519423395
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015538500883849338
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004016680999484379
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013662099809153005
STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014305799413705245
STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002699020114960149
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016476198652526364
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00046992699935799465
STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022256100055528805
STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022111900034360588
STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.120399772655219e-05
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003470090086921118
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034353700175415725
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026978999812854454
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026462300593266264
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023889799922471866
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002689320099307224
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0008166290062945336
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 29 x 1t
INIT: 1 x 7t
NODE: 4 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 11 x MatMul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-VNC.remove_unused] remove_initializer 1:5/30:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 2:6/30:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 3:7/30:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 4:8/30:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 5:9/30:init7_s1_-1:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 6:10/30:init1_s1_:float32[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 7:11/30:init1_s1_2:float32[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 8:16/30:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 9:22/30:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-VNC.optimize] done with 29 nodes in 0.050
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add00, unused, unused2, add
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, _onx_transpose_p_decoder_attention_linear_weight0) -> _onx_matmul_cat0
Add(_onx_matmul_cat0, decoder.attention.linear.bias) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_100, unused3, unused4, add_1
MatMul(_onx_div_sub_add_100, _onx_transpose_p_decoder_feed_forward_linear_1_weight0) -> _onx_matmul_layer_norm_10
Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, _onx_transpose_p_decoder_feed_forward_linear_2_weight0) -> _onx_matmul_relu0
Add(_onx_matmul_relu0, decoder.feed_forward.linear_2.bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
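Whatever processor the optimizer targeted, the exported model can still be validated against the eager model. A quick sketch with InferenceSession and max_diff (both imported at the top of this example), running on CPU here for simplicity:

# Run the optimized model with onnxruntime and compare with PyTorch.
sess = InferenceSession(
    onx_cuda_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)
expected = llm(input_ids)
got = sess.run(None, {"input_ids": input_ids.numpy()})[0]
print("discrepancies:", max_diff(expected, got))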
Comparison between the optimized and non-optimized models¶
The following tool tries to match the nodes and inferred shapes of two models. If they are not too different, the function is able to spot the differences. We can use it to see which operators were fused into bigger ones implemented only by onnxruntime. In the output below, '=' marks a result aligned and identical in both models, '~' a result aligned but different, and '-' a result present only in the first model.
res1, res2, align, dc = compare_onnx_execution(
onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 68 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 68 results (first model)
[compare_onnx_execution] got 58 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 68 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32 2:256x256 AOCQ b_ | INITIA float32 1:1 ?AAA in
002 - | INITIA float32 2:256x256 AOCQ b_ |
003 - | INITIA float32 1:1 ?AAA in |
004 = | INITIA float32 2:16x16 ABAB _o | INITIA float32 2:16x16 ABAB _o
005 = | INITIA float32 2:16x16 AAAA _o | INITIA float32 2:16x16 AAAA _o
006 = | INITIA float32 2:16x16 ABAB _o | INITIA float32 2:16x16 ABAB _o
007 ~ | INITIA float32 1:1 AAAA _r | INITIA float32 2:30x30 KGSP sl
008 = | INITIA float32 1:1 AAAA _r | INITIA float32 1:1 AAAA _r
009 = | INITIA float32 2:16x16 AACB _o | INITIA float32 2:16x16 AACB _o
010 = | INITIA float32 2:16x16 BBAZ _o | INITIA float32 2:16x16 BBAZ _o
011 = | INITIA float32 2:16x16 BZYZ _o | INITIA float32 2:16x16 BZYZ _o
012 ~ | INITIA float32 1:1 AAAA _r | INITIA float32 2:30x30 KGSP sl
013 = | INITIA float32 1:1 AAAA _r | INITIA float32 1:1 AAAA _r
014 = | INITIA float32 2:32x16 ZAAB _o | INITIA float32 2:32x16 ZAAB _o
015 = | INITIA float32 2:16x128 AFZT _o | INITIA float32 2:16x128 AFZT _o
016 = | INITIA float32 2:128x16 AAAB _o | INITIA float32 2:128x16 AAAB _o
017 = | INITIA float32 1:16 EEEE in | INITIA float32 1:16 EEEE in
018 = | INITIA float32 1:16 AAAA in | INITIA float32 1:16 AAAA in
019 - | INITIA int64 1:2 AAAA Sl |
020 - | INITIA int64 1:2 EEAA Sl |
021 - | INITIA int64 1:2 ABAA Sl |
022 - | INITIA int64 1:2 AAAA Sl |
023 - | INITIA int64 1:2 EEAA Sl |
024 - | INITIA int64 1:2 ABAA Sl |
025 = | INITIA float32 2:1024x16 SCTO em | INITIA float32 2:1024x16 SCTO em
026 = | INITIA float32 2:1024x16 BMVK em | INITIA float32 2:1024x16 BMVK em
027 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
028 = | INITIA float32 1:128 AAAA de | INITIA float32 1:128 AAAA de
029 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
030 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
031 = | RESULT float32 3:1x30x16 ASVG Gather em | RESULT float32 3:1x30x16 ASVG Gather em
032 = | RESULT float32 3:1x30x16 HATW Gather em | RESULT float32 3:1x30x16 HATW Gather em
033 ~ | RESULT float32 3:1x30x16 ISNB Add ad | RESULT float32 3:1x30x16 CYAA SkipLayerNormal _o
034 ~ | RESULT float32 3:1x30x16 CYAA LayerNormalizat _o | RESULT float32 3:1x30x1 ACAA SkipLayerNormal un
035 ~ | RESULT float32 3:1x30x16 WAXA MatMul li | RESULT float32 3:1x30x1 GFGE SkipLayerNormal un
036 ~ | RESULT float32 3:1x30x16 FGFA MatMul li | RESULT float32 3:1x30x16 ISNB SkipLayerNormal ad
037 ~ | RESULT float32 3:1x30x16 XTUU MatMul li | RESULT float32 3:1x30x16 WAXA MatMul li
038 ~ | RESULT float32 3:1x16x30 NFYC Transpose tr | RESULT float32 3:1x30x16 FGFA MatMul li
039 ~ | RESULT float32 3:1x30x30 EAHJ MatMul ma | RESULT float32 3:1x30x30 BUIC FusedMatMul _o
040 ~ | RESULT float32 3:1x30x30 BUIC Mul _o | RESULT float32 3:1x30x16 XTUU MatMul li
041 - | RESULT float32 2:30x30 KGSP Slice sl |
042 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
043 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
044 = | RESULT float32 3:1x30x30 IHHH Softmax so | RESULT float32 3:1x30x30 IHHH Softmax so
045 = | RESULT float32 3:1x30x16 TVUU MatMul ma | RESULT float32 3:1x30x16 TVUU MatMul ma
046 = | RESULT float32 3:1x30x16 NBDY MatMul li | RESULT float32 3:1x30x16 NBDY MatMul li
047 = | RESULT float32 3:1x30x16 VASX MatMul li | RESULT float32 3:1x30x16 VASX MatMul li
048 ~ | RESULT float32 3:1x30x16 IZSY MatMul li | RESULT float32 3:1x30x30 RCYW FusedMatMul _o
049 ~ | RESULT float32 3:1x16x30 AOUZ Transpose tr | RESULT float32 3:1x30x16 IZSY MatMul li
050 ~ | RESULT float32 3:1x30x30 QLQJ MatMul ma | RESULT bool 2:30x30 HLZC Equal eq
051 ~ | RESULT float32 3:1x30x30 RCYW Mul _o | RESULT float32 3:1x30x30 ???? Where ma
052 ~ | RESULT float32 2:30x30 KGSP Slice sl | RESULT float32 3:1x30x30 IGHH Softmax so
053 - | RESULT bool 2:30x30 HLZC Equal eq |
054 ~ | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x16 QCGG MatMul ma
055 ~ | RESULT float32 3:1x30x30 IGHH Softmax so | RESULT float32 3:1x30x32 IYZC Concat ca
056 ~ | RESULT float32 3:1x30x16 QCGG MatMul ma | RESULT float32 3:1x30x16 AAAC MatMul _o
057 ~ | RESULT float32 3:1x30x32 IYZC Concat ca | RESULT float32 3:1x30x16 ZYAA Add li
058 ~ | RESULT float32 3:1x30x16 AAAC MatMul _o | RESULT float32 3:1x30x16 CYAA SkipLayerNormal _o
059 ~ | RESULT float32 3:1x30x16 ZYAA Add li | RESULT float32 3:1x30x1 ABAA SkipLayerNormal un
060 ~ | RESULT float32 3:1x30x16 GQNC Add ad | RESULT float32 3:1x30x1 GFGE SkipLayerNormal un
061 ~ | RESULT float32 3:1x30x16 CYAA LayerNormalizat _o | RESULT float32 3:1x30x16 GQNC SkipLayerNormal ad
062 = | RESULT float32 3:1x30x128 LJDZ MatMul _o | RESULT float32 3:1x30x128 LJDZ MatMul _o
063 = | RESULT float32 3:1x30x128 JHBX Add li | RESULT float32 3:1x30x128 JHBX Add li
064 = | RESULT float32 3:1x30x128 NHIT Relu re | RESULT float32 3:1x30x128 NHIT Relu re
065 = | RESULT float32 3:1x30x16 CDCA MatMul _o | RESULT float32 3:1x30x16 CDCA MatMul _o
066 = | RESULT float32 3:1x30x16 DFEC Add li | RESULT float32 3:1x30x16 DFEC Add li
067 = | RESULT float32 3:1x30x16 KVRE Add ou | RESULT float32 3:1x30x16 KVRE Add ou
068 = | OUTPUT float32 3:1x30x16 KVRE ou | OUTPUT float32 3:1x30x16 KVRE ou
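The table shows, for instance, the pair Add + LayerNormalization on the left matched to a single SkipLayerNormalization on the right, which also returns the intermediate sum reused by the residual branch. A small numpy sketch of those semantics (the epsilon value is an assumption, the graph prints it rounded):

import numpy as np

def skip_layer_norm(x, skip, scale, bias, eps=1e-5):
    # SkipLayerNormalization ~ LayerNormalization(x + skip) over the last axis,
    # returning both the normalized output and the sum itself.
    added = x + skip
    mean = added.mean(axis=-1, keepdims=True)
    var = added.var(axis=-1, keepdims=True)
    return (added - mean) / np.sqrt(var + eps) * scale + bias, added

x = np.random.rand(1, 30, 16).astype(np.float32)
normed, added = skip_layer_norm(x, x, np.ones(16, np.float32), np.zeros(16, np.float32))
print(normed.shape, added.shape)  # (1, 30, 16) (1, 30, 16)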
The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that's a topic for another example.
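A hedged sketch of what that could look like, assuming to_onnx forwards a dynamic_shapes argument in the same format as torch.export.export (the exact spelling is an assumption, not verified here):

# Hypothetical: declare dimension 1 of input_ids as dynamic.
seq = torch.export.Dim("seq", min=2, max=30)
onx_dynamic = to_onnx(
    llm,
    (input_ids,),
    dynamic_shapes=({1: seq},),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
)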
Total running time of the script: (0 minutes 2.474 seconds)
Related examples

to_onnx and padding one dimension to a multiple of a constant

torch.onnx.export and padding one dimension to a multiple of a constant