to_onnx and submodules from LLMs¶
Big models are hard to read once converted into ONNX. Let's see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.
A simple LLM¶
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
import torch
from onnxruntime import InferenceSession
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.helpers import pretty_onnx, max_diff
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe
class AttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        B, T, C = x.size()
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)
        out = attention @ value
        return out
class MultiAttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x
class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x
class DecoderLayer(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x
        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention
        return ff
class LLM(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.848771095275879, max=4.01364278793335
First conversion to ONNX¶
The conversion relies on torch.export.export(), which gives:
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)
Then the function to_onnx converts it into ONNX.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init1_s_' type=float32 shape=() -- array([0.25], dtype=float32)-- shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2' type=float32 shape=() -- array([0.], dtype=float32)-- shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32]) -- MatMulAddPattern.new_shape.1
init: name='init7_s3_1_30_-1' type=int64 shape=(3,) -- array([ 1, 30, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16]) -- MatMulAddPattern.new_shape.1
init: name='init7_s2_-1_128' type=int64 shape=(2,) -- array([ -1, 128])-- MatMulAddPattern.new_shape.1
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.attention.0.query.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='decoder.attention.attention.0.key.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='decoder.attention.attention.0.value.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='decoder.attention.attention.1.query.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='decoder.attention.attention.1.key.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='decoder.attention.attention.1.value.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='decoder.attention.linear.weight' type=float32 shape=(16, 32)-- DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.weight' type=float32 shape=(128, 16)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.weight' type=float32 shape=(16, 128)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add00
Transpose(decoder.attention.attention.0.query.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_0_query_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
Transpose(decoder.attention.attention.0.key.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_0_key_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Transpose(decoder.attention.attention.0.value.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_0_value_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
Reshape(init1_s_, init7_s1_1) -> _reshape_init1_s_0
Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Reshape(init1_s_2, init7_s1_1) -> _reshape_init1_s_20
Equal(slice_2, _reshape_init1_s_20) -> eq
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, linear_2) -> matmul_1
Transpose(decoder.attention.attention.1.query.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_1_query_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
Transpose(decoder.attention.attention.1.key.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_1_key_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Transpose(decoder.attention.attention.1.value.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_1_value_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Reshape(init1_s_, init7_s1_1) -> _reshape_init1_s_02
Mul(matmul_2, _reshape_init1_s_02) -> _onx_mul_matmul_20
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start2
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end2
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis2
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_4
Reshape(init1_s_2, init7_s1_1) -> _reshape_init1_s_202
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, decoder.attention.linear.weight, decoder.attention.linear.bias, transB=1) -> MatMulAddPattern--cat2
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_100
Reshape(_onx_div_sub_add_100, init7_s2_-1_16) -> MatMulAddPattern--_onx_div_sub_add_100
Gemm(MatMulAddPattern--_onx_div_sub_add_100, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias, transB=1) -> SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002
Relu(SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002) -> SwitchReshapeActivationPatternL_linear_7
Reshape(SwitchReshapeActivationPatternL_linear_7, init7_s3_1_30_-1) -> relu
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_2.bias, transB=1) -> MatMulAddPattern--relu2
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check there is no discrepancy.
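A minimal sketch of such a check with onnxruntime (assuming max_diff returns a dictionary whose "abs" entry is the maximum absolute error):

# run the ONNX model with onnxruntime and compare with the torch output
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={max_diff(y, got)['abs']}")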
output: shape=(1, 30, 16), min=-4.848771095275879, max=4.01364278793335
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules¶
Let’s produce an ONNX model with submodules.
Function to_onnx calls the function torch.export.unflatten()
under the hood. The fx graph looks like the following.
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy, but it does not necessarily preserve the signatures
of the initial modules. That was not one of our goals.
The tricky part is that the module called (embedding) is not an instance of Embedding
but an instance of InterpreterModule,
which contains the fx nodes contributing to the submodule, coming from the
previous graph.
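A quick way to see this (a sketch, assuming the unflattened module exposes its children through the standard torch.nn.Module API):

sub = unflatten_ep.get_submodule("embedding")
print(type(sub))   # an InterpreterModule, not the original Embedding
print(sub.graph)   # the fx nodes contributing to this submodule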
Now the ONNX graph.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='decoder.attention.attention.0.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.query.weight)
init: name='decoder.attention.attention.0.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.key.weight)
init: name='decoder.attention.attention.0.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.value.weight)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='decoder.attention.attention.1.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.query.weight)
init: name='decoder.attention.attention.1.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.key.weight)
init: name='decoder.attention.attention.1.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.value.weight)
init: name='decoder.attention.linear.weight' type=float32 shape=(16, 32)-- GraphBuilder.make_local_function/from(decoder.attention.linear.weight)
init: name='decoder.feed_forward.linear_1.weight' type=float32 shape=(128, 16)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.weight)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.weight' type=float32 shape=(16, 128)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_2.weight)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
__main__.DecoderLayer[aten_local_function](embedding, mask2, mask, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
Constant(value=[1.0, 1.0,...) -> init1_s16_
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'weight'
Transpose(weight, perm=[1,0]) -> _onx_transpose_weight0
MatMul(layer_norm, _onx_transpose_weight0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
Constant(value=0.25) -> init1_s_
Constant(value=[1]) -> init7_s1_1
Reshape(init1_s_, init7_s1_1) -> _reshape_init1_s_0
Constant(value=[0]) -> init7_s1_0
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30]) -> init7_s1_30
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Constant(value=0.0) -> init1_s_2
Reshape(init1_s_2, init7_s1_1) -> _reshape_init1_s_20
Constant(value=[-inf]) -> init1_s1_
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.query.weight) -> query
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.key.weight) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.value.weight) -> value
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, _reshape_init1_s_20) -> eq
Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'cat'
input: 'weight'
Constant(value=[-0.018796...) -> bias
Constant(value=[-1, 32]) -> init7_s2_-1_32
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, weight, bias, transB=1) -> MatMulAddPattern--cat2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> output
Constant(value=[-0.018796...) -> decoder.attention.linear.bias
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask2'
input: 'mask'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
__main__.AttentionBlock[aten_local_function](layer_norm, mask, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, mask2, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
Linear_2[aten_local_function](cat, decoder.attention.linear.weight) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'weight'
input: 'bias'
Constant(value=[-1, 16]) -> init7_s2_-1_16
Reshape(layer_norm_1, init7_s2_-1_16) -> MatMulAddPattern--layer_norm_1
Gemm(MatMulAddPattern--layer_norm_1, weight, bias, transB=1) -> MatMulAddPattern--layer_norm_12
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--layer_norm_12, init7_s3_1_30_-1) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=Linear_2_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'relu'
input: 'weight'
Constant(value=[0.0637946...) -> bias
Constant(value=[-1, 128]) -> init7_s2_-1_128
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, weight, bias, transB=1) -> MatMulAddPattern--relu2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> output
Constant(value=[0.0637946...) -> decoder.feed_forward.linear_2.bias
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.feed_forward.linear_1.bias'
Linear_3[aten_local_function](layer_norm_1, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> linear_1
ReLU[aten_local_function](linear_1) -> relu
Linear_2_2[aten_local_function](relu, decoder.feed_forward.linear_2.weight) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
input: 'mask2'
input: 'mask'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
input: 'decoder.feed_forward.linear_1.bias'
LayerNorm[aten_local_function](add) -> norm_1
__main__.MultiAttentionBlock[aten_local_function](norm_1, mask2, mask, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention
Add(attention, add) -> add_1
LayerNorm[aten_local_function](add_1) -> norm_2
__main__.FeedForward[aten_local_function](norm_2, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> feed_forward
Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?
We check again that there are no new discrepancies.
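Again a sketch of the check, this time on the model with local functions (assuming onnxruntime can execute them):

sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, dict(input_ids=input_ids.numpy()))[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={max_diff(y, got)['abs']}")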
output: shape=(1, 30, 16), min=-4.848771095275879, max=4.01364278793335
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually.
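A sketch of the rendering, assuming plot_dot (imported above) accepts a ModelProto and returns a matplotlib axis:

import matplotlib.pyplot as plt

ax = plot_dot(onx_module)  # render the graph with graphviz
ax.set_title("LLM with submodules as local functions")
plt.show()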

Inlining¶
The ONNX graph can still be inlined after this.
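A sketch of the inlining step with the function imported above; inline_local_functions expands every call to a local function into the main graph:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))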
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='decoder.attention.attention.0.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.query.weight)
init: name='decoder.attention.attention.0.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.key.weight)
init: name='decoder.attention.attention.0.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.value.weight)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='decoder.attention.attention.1.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.query.weight)
init: name='decoder.attention.attention.1.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.key.weight)
init: name='decoder.attention.attention.1.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.value.weight)
init: name='decoder.attention.linear.weight' type=float32 shape=(16, 32)-- GraphBuilder.make_local_function/from(decoder.attention.linear.weight)
init: name='decoder.feed_forward.linear_1.weight' type=float32 shape=(128, 16)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.weight)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.weight' type=float32 shape=(16, 128)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_2.weight)
Constant(value=[1]) -> init7_s1_1__11
Gather(embedding.embedding.weight, input_ids) -> embedding__1
Gather(embedding.pe.weight, input_ids) -> pe__1
Add(embedding__1, pe__1) -> embedding
Constant(value=[1.0, 1.0,...) -> init1_s16___5
Constant(value=[0.0, 0.0,...) -> init1_s16_2__5
LayerNormalization(embedding, init1_s16___5, init1_s16_2__5, axis=-1, epsilon=0.00, stash_type=1) -> norm_1__4
Constant(value=0.25) -> init1_s___7
Constant(value=[1]) -> init7_s1_1__7
Reshape(init1_s___7, init7_s1_1__7) -> _reshape_init1_s_0__7
Constant(value=[0]) -> init7_s1_0__7
Concat(init7_s1_0__7, init7_s1_0__7, axis=0) -> SliceSlicePattern_init7_s1_0_start__7
Constant(value=[30]) -> init7_s1_30__7
Concat(init7_s1_30__7, init7_s1_30__7, axis=0) -> SliceSlicePattern_init7_s1_30_end__7
Constant(value=0.0) -> init1_s_2__7
Reshape(init1_s_2__7, init7_s1_1__7) -> _reshape_init1_s_20__7
Constant(value=[-inf]) -> init1_s1___7
Transpose(decoder.attention.attention.0.query.weight, perm=[1,0]) -> _onx_transpose_weight0__8
MatMul(norm_1__4, _onx_transpose_weight0__8) -> query__7
Transpose(decoder.attention.attention.0.key.weight, perm=[1,0]) -> _onx_transpose_weight0__9
MatMul(norm_1__4, _onx_transpose_weight0__9) -> key__7
Transpose(key__7, perm=[0,2,1]) -> transpose__7
MatMul(query__7, transpose__7) -> matmul__7
Mul(matmul__7, _reshape_init1_s_0__7) -> _onx_mul_matmul0__7
Transpose(decoder.attention.attention.0.value.weight, perm=[1,0]) -> _onx_transpose_weight0__10
MatMul(norm_1__4, _onx_transpose_weight0__10) -> value__7
Concat(init7_s1_0__7, init7_s1_1__7, axis=0) -> SliceSlicePattern_init7_s1_1_axis__7
Slice(mask, SliceSlicePattern_init7_s1_0_start__7, SliceSlicePattern_init7_s1_30_end__7, SliceSlicePattern_init7_s1_1_axis__7) -> slice_2__7
Equal(slice_2__7, _reshape_init1_s_20__7) -> eq__7_1
Where(eq__7_1, init1_s1___7, _onx_mul_matmul0__7) -> masked_fill__7
Softmax(masked_fill__7, axis=-1) -> softmax__7
MatMul(softmax__7, value__7) -> attention_0__6
Constant(value=0.25) -> init1_s___11
Reshape(init1_s___11, init7_s1_1__11) -> _reshape_init1_s_0__11
Constant(value=[0]) -> init7_s1_0__11
Concat(init7_s1_0__11, init7_s1_0__11, axis=0) -> SliceSlicePattern_init7_s1_0_start__11
Constant(value=[30]) -> init7_s1_30__11
Concat(init7_s1_30__11, init7_s1_30__11, axis=0) -> SliceSlicePattern_init7_s1_30_end__11
Constant(value=0.0) -> init1_s_2__11
Reshape(init1_s_2__11, init7_s1_1__11) -> _reshape_init1_s_20__11
Constant(value=[-inf]) -> init1_s1___11
Transpose(decoder.attention.attention.1.query.weight, perm=[1,0]) -> _onx_transpose_weight0__12
MatMul(norm_1__4, _onx_transpose_weight0__12) -> query__11
Transpose(decoder.attention.attention.1.key.weight, perm=[1,0]) -> _onx_transpose_weight0__13
MatMul(norm_1__4, _onx_transpose_weight0__13) -> key__11
Transpose(key__11, perm=[0,2,1]) -> transpose__11
MatMul(query__11, transpose__11) -> matmul__11
Mul(matmul__11, _reshape_init1_s_0__11) -> _onx_mul_matmul0__11
Transpose(decoder.attention.attention.1.value.weight, perm=[1,0]) -> _onx_transpose_weight0__14
MatMul(norm_1__4, _onx_transpose_weight0__14) -> value__11
Concat(init7_s1_0__11, init7_s1_1__11, axis=0) -> SliceSlicePattern_init7_s1_1_axis__11
Slice(mask2, SliceSlicePattern_init7_s1_0_start__11, SliceSlicePattern_init7_s1_30_end__11, SliceSlicePattern_init7_s1_1_axis__11) -> slice_2__11
Equal(slice_2__11, _reshape_init1_s_20__11) -> eq__11_2
Where(eq__11_2, init1_s1___11, _onx_mul_matmul0__11) -> masked_fill__11
Softmax(masked_fill__11, axis=-1) -> softmax__11
MatMul(softmax__11, value__11) -> attention_1__6
Concat(attention_0__6, attention_1__6, axis=-1) -> cat__6_0
Constant(value=[-0.018796...) -> bias__15
Constant(value=[-1, 32]) -> init7_s2_-1_32__15
Reshape(cat__6_0, init7_s2_-1_32__15) -> MatMulAddPattern--cat__15
Gemm(MatMulAddPattern--cat__15, decoder.attention.linear.weight, bias__15, transB=1) -> MatMulAddPattern--cat2__15
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1__15
Reshape(MatMulAddPattern--cat2__15, init7_s3_1_30_-1__15) -> attention__4
Add(attention__4, embedding) -> add_1__4
Constant(value=[-0.018796...) -> decoder.attention.linear.bias__15
Constant(value=[1.0, 1.0,...) -> init1_s16___16
Constant(value=[0.0, 0.0,...) -> init1_s16_2__16
LayerNormalization(add_1__4, init1_s16___16, init1_s16_2__16, axis=-1, epsilon=0.00, stash_type=1) -> norm_2__4
Constant(value=[-1, 16]) -> init7_s2_-1_16__18
Reshape(norm_2__4, init7_s2_-1_16__18) -> MatMulAddPattern--layer_norm_1__18
Gemm(MatMulAddPattern--layer_norm_1__18, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias, transB=1) -> MatMulAddPattern--layer_norm_12__18
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1__18
Reshape(MatMulAddPattern--layer_norm_12__18, init7_s3_1_30_-1__18) -> linear_1__17
Relu(linear_1__17) -> relu__17
Constant(value=[0.0637946...) -> bias__20
Constant(value=[-1, 128]) -> init7_s2_-1_128__20
Reshape(relu__17, init7_s2_-1_128__20) -> MatMulAddPattern--relu__20
Gemm(MatMulAddPattern--relu__20, decoder.feed_forward.linear_2.weight, bias__20, transB=1) -> MatMulAddPattern--relu2__20
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1__20
Reshape(MatMulAddPattern--relu2__20, init7_s3_1_30_-1__20) -> feed_forward__4
Add(feed_forward__4, add_1__4) -> output_0
Constant(value=[0.0637946...) -> decoder.feed_forward.linear_2.bias__20
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations¶
The ONNX graph produced by the exporter without any optimization is verbose and less efficient than it could be. That's why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let's see how it goes.
onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2
    ),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-SLK.optimize] start with 73 nodes
[GraphBuilder-SLK.optimize] #patterns=62
[GraphBuilder-SLK.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 3:5/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 4:7/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 9:17/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 10:19/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 11:21/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-SLK.remove_unused] remove_initializer 12:23/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 13:25/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 14:27/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 15:29/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 16:31/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-SLK.remove_unused] remove_initializer 17:33/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-SLK.remove_unused] remove_initializer 18:35/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 1:4/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 2:5/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 3:6/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 4:7/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 5:8/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 6:9/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 7:10/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-SLK.remove_unused] remove_initializer 8:14/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 9:16/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-SLK.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-SLK.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-SLK.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-SLK.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-SLK.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-SLK.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-SLK.optimize] start with 53 nodes, 28 initializers, 62 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 1/62 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 2/62 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 3/62 - P0 - CastPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 4/62 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 5/62 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 6/62 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 7/62 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 8/62 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 9/62 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 10/62 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 11/62 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 12/62 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 13/62 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 14/62 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 15/62 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 16/62 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 17/62 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 18/62 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 19/62 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 20/62 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 21/62 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 22/62 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 23/62 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 24/62 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 25/62 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 26/62 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 27/62 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 28/62 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 29/62 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 30/62 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 31/62 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 32/62 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 33/62 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 34/62 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 35/62 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 36/62 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 37/62 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 38/62 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 39/62 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 40/62 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 41/62 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 42/62 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 43/62 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 44/62 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 45/62 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 46/62 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 47/62 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 48/62 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 49/62 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 50/62 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 51/62 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 52/62 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 53/62 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 54/62 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 55/62 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 56/62 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 57/62 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 58/62 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 59/62 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 60/62 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 61/62 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-SLK.optimize] use pattern 62/62 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-SLK.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-SLK.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.009 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilderPatternOptimization-SLK.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-SLK.optimize] increase priority to 1
[GraphBuilderPatternOptimization-SLK.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-SLK.optimize] applies 5 matches, 2*LayerNormalizationPattern, 3*MatMulAddPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-SLK.optimize] iteration 3: 38 nodes, priority=1
[GraphBuilderPatternOptimization-SLK.optimize] applies 3 matches, 3*GemmTransposePattern - time=0.003 | max_time=GeluOrtPattern:0.000
[GraphBuilderPatternOptimization-SLK.optimize] iteration 4: 41 nodes, priority=1
[GraphBuilderPatternOptimization-SLK.optimize] applies 1 matches, [0]=MatchResult: SwitchReshapeActivationPattern replaces ['Gemm', 'Reshape', 'Relu'] - time=0.003 | max_time=MatMulAddPattern:0.000
[GraphBuilderPatternOptimization-SLK.optimize] iteration 5: 41 nodes, priority=1
[GraphBuilderPatternOptimization-SLK.optimize] increase priority to 2
[GraphBuilderPatternOptimization-SLK.optimize] iteration 6: 41 nodes, priority=2
[GraphBuilderPatternOptimization-SLK.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=FusedMatMulPattern:0.000
[GraphBuilderPatternOptimization-SLK.optimize] iteration 7: 37 nodes, priority=2
[GraphBuilderPatternOptimization-SLK.optimize] increase priority to 3
[GraphBuilderPatternOptimization-SLK.optimize] iteration 8: 37 nodes, priority=3
[GraphBuilderPatternOptimization-SLK.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-SLK.optimize] done after 9 iterations with 37 nodes in 0.057
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002774419999695965
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.000743970000257832
STAT apply_GemmTransposePattern +6 -3 #it=1 maxmatch=2 i=3 - time=0.0006179430010888609
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0003867269988404587
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.000427855000452837
STAT apply_MatMulAddPattern +9 -6 #it=1 maxmatch=4 i=3 - time=0.0010611870002321666
STAT apply_SwitchReshapeActivationPattern +3 -3 #it=1 maxmatch=0 i=1 - time=0.00033681199965940323
STAT build_graph_for_pattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.002385127001616638
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0001356900002065231
STAT check_pattern_A0 +0 -0 #it=5 maxmatch=0 i=0 - time=0.0023776759990141727
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.0004244499996275408
STAT match_BatchNormalizationPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.0005061700012447545
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.00040999799875862664
STAT match_BiasGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003304389983895817
STAT match_BiasSoftmaxPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003006020006068866
STAT match_CastCastBinaryPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.000638119001450832
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00041285899987997254
STAT match_CastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0007076430001689005
STAT match_CastPattern +0 -0 #it=9 maxmatch=2 i=2 - time=0.0005087529989395989
STAT match_ClipClipPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0003066160015805508
STAT match_ComputationCastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0004415840003275662
STAT match_ConvBiasNullPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.00041970299935201183
STAT match_DropoutPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0002615959992908756
STAT match_ExpandBroadcastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0002902229998653638
STAT match_ExpandPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.0004202009995424305
STAT match_ExpandSwapPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00027260400111117633
STAT match_FastGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00034409799900458893
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00012447899916878669
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0002562809986557113
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0005595439997705398
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.598800064239185e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.776499998930376e-05
STAT match_GeluErfPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0029429580017676926
STAT match_GeluOrtPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.004003896000540408
STAT match_GeluPattern +0 -0 #it=9 maxmatch=2 i=0 - time=9.273999239667319e-06
STAT match_GemmTransposePattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.00046273800126073183
STAT match_IdentityPattern +0 -0 #it=9 maxmatch=6 i=4 - time=0.002819929999532178
STAT match_LayerNormalizationPattern +0 -0 #it=7 maxmatch=2 i=2 - time=0.00043776200072898064
STAT match_LayerNormalizationScalePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0003266580006311415
STAT match_LeakyReluPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00307208700087358
STAT match_MatMulAddPattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.0010426249982629088
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0009158190014204592
STAT match_MulMulMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0005581060004260507
STAT match_MulMulMulScalarPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00038291499822662445
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00042769800074893283
STAT match_QuickGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002753689987002872
STAT match_ReduceReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0003962739992857678
STAT match_ReduceSumNormalizePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00028343400117591955
STAT match_Reshape2Of3Pattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0007751890007057227
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0004962349994457327
STAT match_ReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0007220110010166536
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00044874400009575766
STAT match_ReshapeReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0006011940013195272
STAT match_RotaryConcatPartPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00037213799987512175
STAT match_SameChildrenPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0010660900006769225
STAT match_SequenceConstructAtPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002859010010070051
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00042446499810466776
STAT match_SliceSlicePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00026982899998984067
STAT match_SlicesSplitPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00026657899979909416
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.005545582999729959
STAT match_SoftmaxGradPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00027877800039277645
STAT match_SplitConcatPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00025085400102398125
STAT match_SqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0004043770013595349
STAT match_Sub1MulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003450860003795242
STAT match_SwitchOrderBinaryPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0007577429987577489
STAT match_SwitchReshapeActivationPattern +0 -0 #it=7 maxmatch=5 i=1 - time=0.0004484370001591742
STAT match_TransposeEqualReshapePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00067874999876949
STAT match_TransposeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0010082369999508956
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0005839910018039518
STAT match_TransposeReshapeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0005249969999567838
STAT match_TransposeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0004713570006060763
STAT match_UnsqueezeEqualPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00043672099855029956
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00043951699990429915
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0009320860008301679
--MODEL: 37 nodes, 1 inputs, 1 outputs, 34 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 29 x 1t
INIT: 5 x 7t
NODE: 3 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 3 x Gemm
NODE: 2 x LayerNormalization
NODE: 8 x MatMul
NODE: 1 x Relu
NODE: 6 x Reshape
NODE: 2 x Softmax
NODE: 3 x Transpose
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
--MODEL: 37 nodes, 1 inputs, 1 outputs, 34 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
INIT: 3 x 7t[2]
INIT: 1 x 7t[3]
NODE: 3 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x Gemm -SIG- 1t[30x128], 1t[16x128], 1t[16]
NODE: 1 x Gemm -SIG- 1t[30x16], 1t[128x16], 1t[128]
NODE: 1 x Gemm -SIG- 1t[30x32], 1t[16x32], 1t[16]
NODE: 2 x LayerNormalization -SIG- 1t[1x30x16], 1t[16], 1t[16]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x Relu -SIG- 1t[30x128]
NODE: 1 x Reshape -SIG- 1t[1x30x128], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x16], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x32], 7t[2]
NODE: 1 x Reshape -SIG- 1t[30x128], 7t[3]
NODE: 2 x Reshape -SIG- 1t[30x16], 7t[3]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 1 x Transpose -SIG- 1t[128x16]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[16x128]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[32x16]-perm=1;0
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
[GraphBuilder-SLK.remove_unused] remove_initializer 1:2/34:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 2:3/34:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 3:5/34:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 4:6/34:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 5:9/34:init7_s1_-1:int64[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 6:10/34:init1_s1_:float32[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 7:11/34:init1_s1_2:float32[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 8:16/34:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 9:22/34:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-SLK.remove_unused] remove_initializer 1:16/28:_onx_transpose_p_decoder_attention_linear_weight0:torch.float32[torch.Size([32, 16])]
[GraphBuilder-SLK.remove_unused] remove_initializer 2:17/28:_onx_transpose_p_decoder_feed_forward_linear_1_weight0:torch.float32[torch.Size([16, 128])]
[GraphBuilder-SLK.remove_unused] remove_initializer 3:18/28:_onx_transpose_p_decoder_feed_forward_linear_2_weight0:torch.float32[torch.Size([128, 16])]
[GraphBuilder-SLK.optimize] done with 34 nodes in 0.068
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32]) -- MatMulAddPattern.new_shape.1
init: name='init7_s3_1_30_-1' type=int64 shape=(3,) -- array([ 1, 30, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16]) -- MatMulAddPattern.new_shape.1
init: name='init7_s2_-1_128' type=int64 shape=(2,) -- array([ -1, 128])-- MatMulAddPattern.new_shape.1
init: name='GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(16, 32)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_attention_linear_weight0)##_onx_transpose_p_decoder_attention_linear_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_1_weight0)##_onx_transpose_p_decoder_feed_forward_linear_1_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_2_weight0)##_onx_transpose_p_decoder_feed_forward_linear_2_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add00
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0, decoder.attention.linear.bias, transB=1) -> MatMulAddPattern--cat2
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_100
Reshape(_onx_div_sub_add_100, init7_s2_-1_16) -> MatMulAddPattern--_onx_div_sub_add_100
Gemm(MatMulAddPattern--_onx_div_sub_add_100, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0, decoder.feed_forward.linear_1.bias, transB=1) -> SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002
Relu(SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002) -> SwitchReshapeActivationPatternL_linear_7
Reshape(SwitchReshapeActivationPatternL_linear_7, init7_s3_1_30_-1) -> relu
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0, decoder.feed_forward.linear_2.bias, transB=1) -> MatMulAddPattern--relu2
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows the kernel FusedMatMul[com.microsoft], which implements an operation equivalent to Gemm but works on tensors of any rank, not only 2D ones.
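The correspondence can be checked directly. The following is a minimal sketch (not part of the original example): it builds a tiny model containing a single com.microsoft FusedMatMul node and verifies with numpy that, with alpha=0.25 and transB=1, it computes alpha * A @ B^T on 3D inputs, something Gemm alone cannot do since Gemm only accepts 2D matrices.

import numpy as np
from onnx import TensorProto, helper
from onnxruntime import InferenceSession

# One FusedMatMul node: alpha * A @ B^T on batched (3D) inputs.
node = helper.make_node(
    "FusedMatMul",
    ["A", "B"],
    ["Y"],
    domain="com.microsoft",
    alpha=0.25,
    transA=0,
    transB=1,
)
graph = helper.make_graph(
    [node],
    "fused_matmul_check",
    [
        helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 30, 16]),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 30, 16]),
    ],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 30, 30])],
)
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 18),
        helper.make_opsetid("com.microsoft", 1),
    ],
)
sess = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
A = np.random.rand(1, 30, 16).astype(np.float32)
B = np.random.rand(1, 30, 16).astype(np.float32)
(got,) = sess.run(None, {"A": A, "B": B})
# Same result as the reference computation, up to float32 rounding.
assert np.allclose(got, 0.25 * A @ B.transpose(0, 2, 1), atol=1e-5)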
How does it work on the model which keeps the modules exported as local functions?
The optimizer optimizes every local function independently.
We reduce the verbosity…
onx_module_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_4)
init: name='GemmTransposePattern--_onx_transpose_weight0' type=float32 shape=(16, 32)-- GraphBuilder.make_local_function/from(GemmTransposePattern--_onx_transpose_weight0)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='GemmTransposePattern--_onx_transpose_weight02' type=float32 shape=(128, 16)-- GraphBuilder.make_local_function/from(GemmTransposePattern--_onx_transpose_weight02)
init: name='GemmTransposePattern--_onx_transpose_weight022' type=float32 shape=(16, 128)-- GraphBuilder.make_local_function/from(GemmTransposePattern--_onx_transpose_weight022)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
__main__.DecoderLayer[aten_local_function](embedding, GemmTransposePattern--_onx_transpose_weight022, GemmTransposePattern--_onx_transpose_weight02, slice_4, slice_2, GemmTransposePattern--_onx_transpose_weight0, _onx_transpose_weight04, _onx_transpose_weight032, _onx_transpose_weight03, _onx_transpose_weight022, _onx_transpose_weight02, _onx_transpose_weight0, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
Constant(value=[1.0, 1.0,...) -> init1_s16_
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: '_onx_transpose_weight0'
MatMul(layer_norm, _onx_transpose_weight0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_2'
input: '_onx_transpose_weight03'
input: '_onx_transpose_weight02'
input: '_onx_transpose_weight0'
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.0]) -> _reshape_init1_s_20
Equal(slice_2, _reshape_init1_s_20) -> eq
Linear[aten_local_function](layer_norm, _onx_transpose_weight0) -> query
Linear[aten_local_function](layer_norm, _onx_transpose_weight02) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
Linear[aten_local_function](layer_norm, _onx_transpose_weight03) -> value
MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'cat'
input: 'GemmTransposePattern--_onx_transpose_weight0'
Constant(value=[-0.018796...) -> bias
Constant(value=[-1, 32]) -> init7_s2_-1_32
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, GemmTransposePattern--_onx_transpose_weight0, bias, transB=1) -> MatMulAddPattern--cat2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> output
Constant(value=[-0.018796...) -> decoder.attention.linear.bias
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_4'
input: 'slice_2'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: '_onx_transpose_weight04'
input: '_onx_transpose_weight032'
input: '_onx_transpose_weight03'
input: '_onx_transpose_weight022'
input: '_onx_transpose_weight02'
input: '_onx_transpose_weight0'
__main__.AttentionBlock[aten_local_function](layer_norm, slice_2, _onx_transpose_weight03, _onx_transpose_weight02, _onx_transpose_weight0) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, slice_4, _onx_transpose_weight032, _onx_transpose_weight022, _onx_transpose_weight04) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
Linear_2[aten_local_function](cat, GemmTransposePattern--_onx_transpose_weight0) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: 'bias'
Constant(value=[-1, 16]) -> init7_s2_-1_16
Reshape(layer_norm_1, init7_s2_-1_16) -> MatMulAddPattern--layer_norm_1
Gemm(MatMulAddPattern--layer_norm_1, GemmTransposePattern--_onx_transpose_weight0, bias, transB=1) -> MatMulAddPattern--layer_norm_12
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--layer_norm_12, init7_s3_1_30_-1) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=Linear_2_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'relu'
input: 'GemmTransposePattern--_onx_transpose_weight0'
Constant(value=[0.0637946...) -> bias
Constant(value=[-1, 128]) -> init7_s2_-1_128
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, GemmTransposePattern--_onx_transpose_weight0, bias, transB=1) -> MatMulAddPattern--relu2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> output
Constant(value=[0.0637946...) -> decoder.feed_forward.linear_2.bias
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: 'GemmTransposePattern--_onx_transpose_weight02'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: 'decoder.feed_forward.linear_1.bias'
Linear_3[aten_local_function](layer_norm_1, GemmTransposePattern--_onx_transpose_weight0, decoder.feed_forward.linear_1.bias) -> linear_1
ReLU[aten_local_function](linear_1) -> relu
Linear_2_2[aten_local_function](relu, GemmTransposePattern--_onx_transpose_weight02) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'add'
input: 'GemmTransposePattern--_onx_transpose_weight022'
input: 'GemmTransposePattern--_onx_transpose_weight02'
input: 'slice_4'
input: 'slice_2'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: '_onx_transpose_weight04'
input: '_onx_transpose_weight032'
input: '_onx_transpose_weight03'
input: '_onx_transpose_weight022'
input: '_onx_transpose_weight02'
input: '_onx_transpose_weight0'
input: 'decoder.feed_forward.linear_1.bias'
LayerNorm[aten_local_function](add) -> norm_1
__main__.MultiAttentionBlock[aten_local_function](norm_1, slice_4, slice_2, GemmTransposePattern--_onx_transpose_weight0, _onx_transpose_weight04, _onx_transpose_weight032, _onx_transpose_weight03, _onx_transpose_weight022, _onx_transpose_weight02, _onx_transpose_weight0) -> attention
Add(attention, add) -> add_1
LayerNorm[aten_local_function](add_1) -> norm_2
__main__.FeedForward[aten_local_function](norm_2, GemmTransposePattern--_onx_transpose_weight022, GemmTransposePattern--_onx_transpose_weight02, decoder.feed_forward.linear_1.bias) -> feed_forward
Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?
It seems to work on this simple case as well, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
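If a runtime does not handle model-local functions well, the function-based model can always be flattened back with onnx.inliner.inline_local_functions, already imported above. The following is a short sketch (assuming max_diff accepts a torch tensor and a numpy array, as it is used for such comparisons in this package) checking the inlined model against the eager PyTorch model:

import torch
from onnx.inliner import inline_local_functions
from onnxruntime import InferenceSession
from experimental_experiment.helpers import max_diff

# Flatten the local functions back into a single graph.
inlined = inline_local_functions(onx_module_optimized)

sess = InferenceSession(inlined.SerializeToString(), providers=["CPUExecutionProvider"])
with torch.no_grad():
    expected = llm(input_ids)
got = sess.run(None, {"input_ids": input_ids.numpy()})[0]
# max_diff reports the discrepancies between both outputs,
# expected to be around machine precision for float32.
print(max_diff(expected, got))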
Optimizations for CUDA¶
The optimizer may behave differently when it knows the model will run on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the best place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more suitable for a given kernel. This could even be decided at runtime.
onx_cuda_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(
patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-DZK.optimize] start with 73 nodes
[GraphBuilder-DZK.optimize] #patterns=62
[GraphBuilder-DZK.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 3:5/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 4:7/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 9:17/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 10:19/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 11:21/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-DZK.remove_unused] remove_initializer 12:23/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 13:25/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 14:27/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 15:29/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 16:31/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-DZK.remove_unused] remove_initializer 17:33/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-DZK.remove_unused] remove_initializer 18:35/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 1:4/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 2:5/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 3:6/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 4:7/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 5:8/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 6:9/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 7:10/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-DZK.remove_unused] remove_initializer 8:14/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 9:16/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-DZK.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-DZK.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-DZK.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-DZK.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-DZK.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-DZK.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-DZK.optimize] start with 53 nodes, 28 initializers, 62 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 1/62 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 2/62 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 3/62 - P0 - CastPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 4/62 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 5/62 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 6/62 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 7/62 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 8/62 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 9/62 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 10/62 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 11/62 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 12/62 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 13/62 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 14/62 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 15/62 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 16/62 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 17/62 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 18/62 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 19/62 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 20/62 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 21/62 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 22/62 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 23/62 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 24/62 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 25/62 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 26/62 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 27/62 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 28/62 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 29/62 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 30/62 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 31/62 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 32/62 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 33/62 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 34/62 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 35/62 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 36/62 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 37/62 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 38/62 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 39/62 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 40/62 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 41/62 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 42/62 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 43/62 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 44/62 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 45/62 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 46/62 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 47/62 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 48/62 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 49/62 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 50/62 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 51/62 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 52/62 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 53/62 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 54/62 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 55/62 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 56/62 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 57/62 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 58/62 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 59/62 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 60/62 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 61/62 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-DZK.optimize] use pattern 62/62 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-DZK.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-DZK.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.008 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-DZK.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-DZK.optimize] increase priority to 1
[GraphBuilderPatternOptimization-DZK.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-DZK.optimize] applies 5 matches, 2*LayerNormalizationPattern, 3*MatMulAddPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-DZK.optimize] iteration 3: 38 nodes, priority=1
[GraphBuilderPatternOptimization-DZK.optimize] applies 3 matches, 3*GemmTransposePattern - time=0.003 | max_time=MatMulAddPattern:0.000
[GraphBuilderPatternOptimization-DZK.optimize] iteration 4: 41 nodes, priority=1
[GraphBuilderPatternOptimization-DZK.optimize] applies 1 matches, [0]=MatchResult: SwitchReshapeActivationPattern replaces ['Gemm', 'Reshape', 'Relu'] - time=0.003 | max_time=LeakyReluPattern:0.000
[GraphBuilderPatternOptimization-DZK.optimize] iteration 5: 41 nodes, priority=1
[GraphBuilderPatternOptimization-DZK.optimize] increase priority to 2
[GraphBuilderPatternOptimization-DZK.optimize] iteration 6: 41 nodes, priority=2
[GraphBuilderPatternOptimization-DZK.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=FusedMatMulPattern:0.000
[GraphBuilderPatternOptimization-DZK.optimize] iteration 7: 37 nodes, priority=2
[GraphBuilderPatternOptimization-DZK.optimize] increase priority to 3
[GraphBuilderPatternOptimization-DZK.optimize] iteration 8: 37 nodes, priority=3
[GraphBuilderPatternOptimization-DZK.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-DZK.optimize] done after 9 iterations with 37 nodes in 0.047
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00026873699971474707
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.000461276000351063
STAT apply_GemmTransposePattern +6 -3 #it=1 maxmatch=2 i=3 - time=0.00047201799952745205
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0003997190015070373
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0006134300010671723
STAT apply_MatMulAddPattern +9 -6 #it=1 maxmatch=4 i=3 - time=0.001205105999360967
STAT apply_SwitchReshapeActivationPattern +3 -3 #it=1 maxmatch=0 i=1 - time=0.0003499000004012487
STAT build_graph_for_pattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.0019760269997277646
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00015775199972267728
STAT check_pattern_A0 +0 -0 #it=5 maxmatch=0 i=0 - time=0.002177924999159586
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.00039020900021569105
STAT match_BatchNormalizationPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.000402683002903359
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.0003800819986281567
STAT match_BiasGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00022643900047114585
STAT match_BiasSoftmaxPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00038076700002420694
STAT match_CastCastBinaryPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0004781300003742217
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0003423549987928709
STAT match_CastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0005399720002969843
STAT match_CastPattern +0 -0 #it=9 maxmatch=2 i=2 - time=0.0003718759999173926
STAT match_ClipClipPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00023752000015520025
STAT match_ComputationCastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0003541289997883723
STAT match_ConvBiasNullPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.00038812999900983414
STAT match_DropoutPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00020776000019395724
STAT match_ExpandBroadcastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0002217920009570662
STAT match_ExpandPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.0003259650020481786
STAT match_ExpandSwapPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0002154200001314166
STAT match_FastGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002262900006826385
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=9.403100011695642e-05
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00017092100006266264
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00035687299987330334
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.954999985784525e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.473099983850261e-05
STAT match_GeluErfPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00258755699996982
STAT match_GeluOrtPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0029185579987824894
STAT match_GeluPattern +0 -0 #it=9 maxmatch=2 i=0 - time=9.69600114331115e-06
STAT match_GemmTransposePattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.00042540700087556615
STAT match_IdentityPattern +0 -0 #it=9 maxmatch=6 i=4 - time=0.002341219001209538
STAT match_LayerNormalizationPattern +0 -0 #it=7 maxmatch=2 i=2 - time=0.0004568570002447814
STAT match_LayerNormalizationScalePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0002638449996084091
STAT match_LeakyReluPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.002840503000697936
STAT match_MatMulAddPattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.0008531129997209064
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0007692680001127883
STAT match_MulMulMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.000443069999164436
STAT match_MulMulMulScalarPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00029810700107191224
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003923499989468837
STAT match_QuickGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002318780007044552
STAT match_ReduceReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00032548000126553234
STAT match_ReduceSumNormalizePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00023054800021782285
STAT match_Reshape2Of3Pattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0005694909996236674
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.000410096999075904
STAT match_ReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0006643510005233111
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0003899179992004065
STAT match_ReshapeReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0005201469975872897
STAT match_RotaryConcatPartPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003012349998243735
STAT match_SameChildrenPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0009375750005347072
STAT match_SequenceConstructAtPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00023661499926674878
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00037342099949455587
STAT match_SliceSlicePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00021994299913785653
STAT match_SlicesSplitPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002885629992306349
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.004117469999073364
STAT match_SoftmaxGradPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002187569980378612
STAT match_SplitConcatPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00022333500055538025
STAT match_SqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0003555850016709883
STAT match_Sub1MulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002600290017653606
STAT match_SwitchOrderBinaryPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0005817310002385057
STAT match_SwitchReshapeActivationPattern +0 -0 #it=7 maxmatch=5 i=1 - time=0.0003461279993643984
STAT match_TransposeEqualReshapePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.000429359000918339
STAT match_TransposeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0007202190008683829
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0004694070003097295
STAT match_TransposeReshapeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00039989400102058426
STAT match_TransposeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0004221899998810841
STAT match_UnsqueezeEqualPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003309240000817226
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0003225450009267661
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0008952879998105345
--MODEL: 37 nodes, 1 inputs, 1 outputs, 34 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 29 x 1t
INIT: 5 x 7t
NODE: 3 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 3 x Gemm
NODE: 2 x LayerNormalization
NODE: 8 x MatMul
NODE: 1 x Relu
NODE: 6 x Reshape
NODE: 2 x Softmax
NODE: 3 x Transpose
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
--MODEL: 37 nodes, 1 inputs, 1 outputs, 34 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
INIT: 3 x 7t[2]
INIT: 1 x 7t[3]
NODE: 3 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x Gemm -SIG- 1t[30x128], 1t[16x128], 1t[16]
NODE: 1 x Gemm -SIG- 1t[30x16], 1t[128x16], 1t[128]
NODE: 1 x Gemm -SIG- 1t[30x32], 1t[16x32], 1t[16]
NODE: 2 x LayerNormalization -SIG- 1t[1x30x16], 1t[16], 1t[16]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x Relu -SIG- 1t[30x128]
NODE: 1 x Reshape -SIG- 1t[1x30x128], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x16], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x32], 7t[2]
NODE: 1 x Reshape -SIG- 1t[30x128], 7t[3]
NODE: 2 x Reshape -SIG- 1t[30x16], 7t[3]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 1 x Transpose -SIG- 1t[128x16]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[16x128]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[32x16]-perm=1;0
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
[GraphBuilder-DZK.remove_unused] remove_initializer 1:2/34:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 2:3/34:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 3:5/34:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 4:6/34:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 5:9/34:init7_s1_-1:int64[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 6:10/34:init1_s1_:float32[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 7:11/34:init1_s1_2:float32[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 8:16/34:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 9:22/34:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-DZK.remove_unused] remove_initializer 1:16/28:_onx_transpose_p_decoder_attention_linear_weight0:torch.float32[torch.Size([32, 16])]
[GraphBuilder-DZK.remove_unused] remove_initializer 2:17/28:_onx_transpose_p_decoder_feed_forward_linear_1_weight0:torch.float32[torch.Size([16, 128])]
[GraphBuilder-DZK.remove_unused] remove_initializer 3:18/28:_onx_transpose_p_decoder_feed_forward_linear_2_weight0:torch.float32[torch.Size([128, 16])]
[GraphBuilder-DZK.optimize] done with 34 nodes in 0.059
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32]) -- MatMulAddPattern.new_shape.1
init: name='init7_s3_1_30_-1' type=int64 shape=(3,) -- array([ 1, 30, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16]) -- MatMulAddPattern.new_shape.1
init: name='init7_s2_-1_128' type=int64 shape=(2,) -- array([ -1, 128])-- MatMulAddPattern.new_shape.1
init: name='GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(16, 32)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_attention_linear_weight0)##_onx_transpose_p_decoder_attention_linear_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_1_weight0)##_onx_transpose_p_decoder_feed_forward_linear_1_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_2_weight0)##_onx_transpose_p_decoder_feed_forward_linear_2_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add00
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0, decoder.attention.linear.bias, transB=1) -> MatMulAddPattern--cat2
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_100
Reshape(_onx_div_sub_add_100, init7_s2_-1_16) -> MatMulAddPattern--_onx_div_sub_add_100
Gemm(MatMulAddPattern--_onx_div_sub_add_100, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0, decoder.feed_forward.linear_1.bias, transB=1) -> SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002
Relu(SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002) -> SwitchReshapeActivationPatternL_linear_7
Reshape(SwitchReshapeActivationPatternL_linear_7, init7_s3_1_30_-1) -> relu
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0, decoder.feed_forward.linear_2.bias, transB=1) -> MatMulAddPattern--relu2
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
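The dump above is the optimized model. A minimal sketch of how it can be produced, assuming the model instance and the input tensor built earlier are named llm and input_ids (hypothetical names here, and the exact options may differ from the ones used in this example):

# Sketch only: "default+onnxruntime" enables the com.microsoft fusions
# (such as FusedMatMul) visible in the graph above.
onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime",
        constant_folding=True,
    ),
)
print(pretty_onnx(onx_optimized))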
Comparing the optimized and non-optimized models¶
The following tool tries to align the nodes and shape inference results of two models. If they are not too different, the function is able to pinpoint the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.
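A minimal sketch of the comparison call, assuming the non-optimized and optimized models are stored in onx and onx_optimized (hypothetical names here). compare_onnx_execution returns the intermediate results of both models, their alignment, and a helper object that renders the table shown below:

res1, res2, align, dc = compare_onnx_execution(onx, onx_optimized, verbose=1)
print("------------")
print(dc.to_str(res1, res2, align))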
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 86 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 86 results (first model)
[compare_onnx_execution] got 61 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 88 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32 2:256x256 AOCQ b_ | INITIA float32 1:1 ?AAA in
002 - | INITIA float32 2:256x256 AOCQ b_ |
003 ~ | INITIA float32 AAAA in | INITIA float32 2:16x16 BBAA _o
004 ~ | INITIA int64 1:1 BAAA in | INITIA float32 2:16x16 ZAYZ _o
005 ~ | INITIA int64 1:1 AAAA in | INITIA float32 2:16x16 AZAA _o
006 ~ | INITIA int64 1:1 EAAA in | INITIA float32 2:30x30 KGSP sl
007 ~ | INITIA float32 AAAA in | INITIA float32 1:1 AAAA _r
008 ~ | INITIA float32 1:1 ?AAA in | INITIA float32 2:16x16 ZBBB _o
009 ~ | INITIA float32 1:16 EEEE in | INITIA float32 2:16x16 AAAA _o
010 ~ | INITIA float32 1:16 AAAA in | INITIA float32 2:16x16 ABAZ _o
011 ~ | INITIA int64 1:2 ZGAA in | INITIA float32 2:30x30 KGSP sl
012 ~ | INITIA int64 1:3 BEZA in | INITIA float32 1:1 AAAA _r
013 ~ | INITIA int64 1:2 ZQAA in | INITIA float32 1:16 EEEE in
014 ~ | INITIA int64 1:2 ZYAA in | INITIA float32 1:16 AAAA in
015 - | INITIA float32 2:1024x16 PUEG em |
016 - | INITIA float32 2:1024x16 UCNT em |
017 ~ | INITIA float32 2:16x16 AABB de | INITIA int64 1:2 ZGAA in
018 ~ | INITIA float32 2:16x16 AYZZ de | INITIA int64 1:3 BEZA in
019 ~ | INITIA float32 2:16x16 AYAA de | INITIA int64 1:2 ZQAA in
020 ~ | INITIA float32 2:16x16 ABBA de | INITIA int64 1:2 ZYAA in
021 ~ | INITIA float32 2:16x16 CAAZ de | INITIA float32 2:16x32 BZAA Ge
022 ~ | INITIA float32 2:16x16 AAAA de | INITIA float32 2:128x16 FAAZ Ge
023 ~ | INITIA float32 2:16x32 BZAA de | INITIA float32 2:16x128 AZBA Ge
024 + | | INITIA float32 2:1024x16 PUEG em
025 + | | INITIA float32 2:1024x16 UCNT em
026 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
027 - | INITIA float32 2:128x16 FAAZ de |
028 = | INITIA float32 1:128 AAAA de | INITIA float32 1:128 AAAA de
029 - | INITIA float32 2:16x128 AZBA de |
030 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
031 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
032 = | RESULT float32 3:1x30x16 SAPX Gather em | RESULT float32 3:1x30x16 SAPX Gather em
033 = | RESULT float32 3:1x30x16 HRCB Gather em | RESULT float32 3:1x30x16 HRCB Gather em
034 = | RESULT float32 3:1x30x16 AQRY Add ad | RESULT float32 3:1x30x16 AQRY Add ad
035 = | RESULT float32 3:1x30x16 CYXD LayerNormalizat _o | RESULT float32 3:1x30x16 CYXD LayerNormalizat _o
036 - | RESULT float32 2:16x16 BBAA Transpose _o |
037 = | RESULT float32 3:1x30x16 FDJH MatMul li | RESULT float32 3:1x30x16 FDJH MatMul li
038 - | RESULT float32 2:16x16 ZAYZ Transpose _o |
039 = | RESULT float32 3:1x30x16 JSEH MatMul li | RESULT float32 3:1x30x16 JSEH MatMul li
040 - | RESULT float32 2:16x16 AZAA Transpose _o |
041 - | RESULT float32 3:1x30x16 YOWB MatMul li |
042 - | RESULT float32 3:1x16x30 CNKN Transpose tr |
043 ~ | RESULT float32 3:1x30x30 PTGR MatMul ma | RESULT float32 3:1x30x30 KZBK FusedMatMul _o
044 - | RESULT float32 1:1 AAAA Reshape _r |
045 ~ | RESULT float32 3:1x30x30 KZBK Mul _o | RESULT float32 3:1x30x16 YOWB MatMul li
046 - | RESULT int64 1:2 AAAA Concat Sl |
047 - | RESULT int64 1:2 EEAA Concat Sl |
048 - | RESULT int64 1:2 ABAA Concat Sl |
049 - | RESULT float32 2:30x30 KGSP Slice sl |
050 - | RESULT float32 1:1 AAAA Reshape _r |
051 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
052 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
053 = | RESULT float32 3:1x30x30 IGHH Softmax so | RESULT float32 3:1x30x30 IGHH Softmax so
054 = | RESULT float32 3:1x30x16 UVVV MatMul ma | RESULT float32 3:1x30x16 UVVV MatMul ma
055 - | RESULT float32 2:16x16 ZBBB Transpose _o |
056 = | RESULT float32 3:1x30x16 YYYA MatMul li | RESULT float32 3:1x30x16 YYYA MatMul li
057 - | RESULT float32 2:16x16 AAAA Transpose _o |
058 = | RESULT float32 3:1x30x16 CLSZ MatMul li | RESULT float32 3:1x30x16 CLSZ MatMul li
059 - | RESULT float32 2:16x16 ABAZ Transpose _o |
060 - | RESULT float32 3:1x30x16 YGRX MatMul li |
061 - | RESULT float32 3:1x16x30 AFUE Transpose tr |
062 ~ | RESULT float32 3:1x30x30 VWCM MatMul ma | RESULT float32 3:1x30x30 ZZHD FusedMatMul _o
063 - | RESULT float32 1:1 AAAA Reshape _r |
064 ~ | RESULT float32 3:1x30x30 ZZHD Mul _o | RESULT float32 3:1x30x16 YGRX MatMul li
065 - | RESULT int64 1:2 AAAA Concat Sl |
066 - | RESULT int64 1:2 EEAA Concat Sl |
067 - | RESULT int64 1:2 ABAA Concat Sl |
068 - | RESULT float32 2:30x30 KGSP Slice sl |
069 - | RESULT float32 1:1 AAAA Reshape _r |
070 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
071 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
072 = | RESULT float32 3:1x30x30 IGHH Softmax so | RESULT float32 3:1x30x30 IGHH Softmax so
073 = | RESULT float32 3:1x30x16 AAAY MatMul ma | RESULT float32 3:1x30x16 AAAY MatMul ma
074 = | RESULT float32 3:1x30x32 UWVT Concat ca | RESULT float32 3:1x30x32 UWVT Concat ca
075 = | RESULT float32 2:30x32 UWVT Reshape Ma | RESULT float32 2:30x32 UWVT Reshape Ma
076 = | RESULT float32 2:30x16 GICD Gemm Ma | RESULT float32 2:30x16 GICD Gemm Ma
077 = | RESULT float32 3:1x30x16 GICD Reshape li | RESULT float32 3:1x30x16 GICD Reshape li
078 = | RESULT float32 3:1x30x16 GYUB Add ad | RESULT float32 3:1x30x16 GYUB Add ad
079 = | RESULT float32 3:1x30x16 BZXD LayerNormalizat _o | RESULT float32 3:1x30x16 BZXD LayerNormalizat _o
080 = | RESULT float32 2:30x16 BZXD Reshape Ma | RESULT float32 2:30x16 BZXD Reshape Ma
081 = | RESULT float32 2:30x128 HWMQ Gemm Sw | RESULT float32 2:30x128 HWMQ Gemm Sw
082 = | RESULT float32 2:30x128 DASA Relu Sw | RESULT float32 2:30x128 DASA Relu Sw
083 = | RESULT float32 3:1x30x128 DASA Reshape re | RESULT float32 3:1x30x128 DASA Reshape re
084 = | RESULT float32 2:30x128 DASA Reshape Ma | RESULT float32 2:30x128 DASA Reshape Ma
085 = | RESULT float32 2:30x16 ZAAA Gemm Ma | RESULT float32 2:30x16 ZAAA Gemm Ma
086 = | RESULT float32 3:1x30x16 ZAAA Reshape li | RESULT float32 3:1x30x16 ZAAA Reshape li
087 = | RESULT float32 3:1x30x16 EXUB Add ou | RESULT float32 3:1x30x16 EXUB Add ou
088 = | OUTPUT float32 3:1x30x16 EXUB ou | OUTPUT float32 3:1x30x16 EXUB ou
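In this table, = marks rows matched in both models, ~ rows that were paired but differ, - results only present in the first (non-optimized) model, and + results only present in the second (optimized) one. The removed Transpose, Mul, Slice and Reshape rows on the left were absorbed by the FusedMatMul nodes and the folded initializers on the right.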
The conversion should also handle dynamic shapes, since the input sequence can be of any length. A full treatment is a topic for another example, but a sketch of such an export follows.
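This is a hedged sketch, assuming to_onnx accepts torch.export-style dynamic_shapes and that the forward argument of the model is named input_ids (both are assumptions):

# Sketch only: mark the sequence dimension (axis 1) as dynamic,
# bounded by the context size the model was built with.
seq = torch.export.Dim("seq", min=1, max=256)
onx_dynamic = to_onnx(
    llm,
    (torch.randint(0, 1024, (1, 30)),),
    dynamic_shapes={"input_ids": {1: seq}},
)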
Total running time of the script: (0 minutes 4.517 seconds)
Related examples
to_onnx and padding one dimension to a multiple of a constant
torch.onnx.export and padding one dimension to a multiple of a constant