to_onnx and submodules from LLMs
Big models are hard to read once converted into ONNX. Let's see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.
A simple LLM
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.helpers import pretty_onnx, max_diff
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        B, T, C = x.size()
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)
        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x
        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention
        return ff


class LLM(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.494992733001709, max=6.392674922943115
First conversion to ONNX
The conversion relies on torch.export.export(), which gives:
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
graph():
%p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
%p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
%p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
%p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
%p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
%p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
%p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
%p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
%p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
%p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
%p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
%p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
%p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
%p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
%p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
%p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
%p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
%p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
%b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
%b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
%input_ids : [num_users=2] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
%embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
%linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
%transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
%matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
%eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
%masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
%softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
%matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
%linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
%linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
%linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
%transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
%matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
%slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
%masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
%softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
%matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
%linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
%add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
%linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
%linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
return (add_2,)
Then the function to_onnx converts it into ONNX.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init1_s_' type=float32 shape=() -- array([0.25], dtype=float32)-- shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2' type=float32 shape=() -- array([0.], dtype=float32)-- shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32]) -- MatMulAddPattern.new_shape.1
init: name='init7_s3_1_30_-1' type=int64 shape=(3,) -- array([ 1, 30, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16]) -- MatMulAddPattern.new_shape.1
init: name='init7_s2_-1_128' type=int64 shape=(2,) -- array([ -1, 128])-- MatMulAddPattern.new_shape.1
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.attention.0.query.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='decoder.attention.attention.0.key.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='decoder.attention.attention.0.value.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='decoder.attention.attention.1.query.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='decoder.attention.attention.1.key.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='decoder.attention.attention.1.value.weight' type=float32 shape=(16, 16)-- DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='decoder.attention.linear.weight' type=float32 shape=(16, 32)-- DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.weight' type=float32 shape=(128, 16)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.weight' type=float32 shape=(16, 128)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add00
Transpose(decoder.attention.attention.0.query.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_0_query_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
Transpose(decoder.attention.attention.0.key.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_0_key_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Transpose(decoder.attention.attention.0.value.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_0_value_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
Reshape(init1_s_, init7_s1_1) -> _reshape_init1_s_0
Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Reshape(init1_s_2, init7_s1_1) -> _reshape_init1_s_20
Equal(slice_2, _reshape_init1_s_20) -> eq
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, linear_2) -> matmul_1
Transpose(decoder.attention.attention.1.query.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_1_query_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
Transpose(decoder.attention.attention.1.key.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_1_key_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Transpose(decoder.attention.attention.1.value.weight, perm=[1,0]) -> _onx_transpose_p_decoder_attention_attention_1_value_weight0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Reshape(init1_s_, init7_s1_1) -> _reshape_init1_s_02
Mul(matmul_2, _reshape_init1_s_02) -> _onx_mul_matmul_20
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start2
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end2
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis2
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_4
Reshape(init1_s_2, init7_s1_1) -> _reshape_init1_s_202
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, decoder.attention.linear.weight, decoder.attention.linear.bias, transB=1) -> MatMulAddPattern--cat2
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_100
Reshape(_onx_div_sub_add_100, init7_s2_-1_16) -> MatMulAddPattern--_onx_div_sub_add_100
Gemm(MatMulAddPattern--_onx_div_sub_add_100, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias, transB=1) -> SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002
Relu(SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002) -> SwitchReshapeActivationPatternL_linear_7
Reshape(SwitchReshapeActivationPatternL_linear_7, init7_s3_1_30_-1) -> relu
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_2.bias, transB=1) -> MatMulAddPattern--relu2
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check there is no discrepancy.
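The verification code itself does not appear in this extract. A minimal sketch of such a check with onnxruntime, assuming max_diff accepts a torch tensor and a numpy array and returns a dictionary with an "abs" entry:

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
# run the exported model on the same input used for the PyTorch model
got = sess.run(None, {"input_ids": input_ids.numpy()})[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
# assumption: 'abs' holds the maximum absolute difference
print(f"max discrepancy={max_diff(y, got)['abs']}")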
output: shape=(1, 30, 16), min=-4.494992733001709, max=6.392674922943115
max discrepancy=2.384185791015625e-07
Let’s save the ONNX model.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules
Let’s produce an ONNX model with submodules.
Function to_onnx
calls torch.export.unflatten.unflatten()
under the hood. The FX graph looks like the following.
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
graph():
%input_ids : [num_users=1] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy but does not necessarily preserve the signatures
of the initial modules. That was not one of our goals.
The tricky part is that the module called (embedding) is not an instance of Embedding
but an instance of InterpreterModule,
which contains the fx nodes contributing to the submodule, taken from the
previous graph.
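Since the unflattened program is a regular torch.nn.Module, one way to observe this is to list its submodules and their types; a minimal sketch:

for name, mod in unflatten_ep.named_modules():
    # every submodule is an InterpreterModule, not the original class
    print(name or "<root>", type(mod).__name__)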
Now the ONNX graph.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='decoder.attention.attention.0.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.query.weight)
init: name='decoder.attention.attention.0.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.key.weight)
init: name='decoder.attention.attention.0.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.value.weight)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='decoder.attention.attention.1.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.query.weight)
init: name='decoder.attention.attention.1.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.key.weight)
init: name='decoder.attention.attention.1.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.value.weight)
init: name='decoder.attention.linear.weight' type=float32 shape=(16, 32)-- GraphBuilder.make_local_function/from(decoder.attention.linear.weight)
init: name='decoder.feed_forward.linear_1.weight' type=float32 shape=(128, 16)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.weight)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.weight' type=float32 shape=(16, 128)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_2.weight)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
__main__.DecoderLayer[aten_local_function](embedding, mask2, mask, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
Constant(value=[1.0, 1.0,...) -> init1_s16_
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'weight'
Transpose(weight, perm=[1,0]) -> _onx_transpose_weight0
MatMul(layer_norm, _onx_transpose_weight0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
Constant(value=0.25) -> init1_s_
Constant(value=[1]) -> init7_s1_1
Reshape(init1_s_, init7_s1_1) -> _reshape_init1_s_0
Constant(value=[0]) -> init7_s1_0
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30]) -> init7_s1_30
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Constant(value=0.0) -> init1_s_2
Reshape(init1_s_2, init7_s1_1) -> _reshape_init1_s_20
Constant(value=[-inf]) -> init1_s1_
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.query.weight) -> query
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.key.weight) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.value.weight) -> value
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
Equal(slice_2, _reshape_init1_s_20) -> eq
Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'cat'
input: 'weight'
Constant(value=[-0.081113...) -> bias
Constant(value=[-1, 32]) -> init7_s2_-1_32
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, weight, bias, transB=1) -> MatMulAddPattern--cat2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> output
Constant(value=[-0.081113...) -> decoder.attention.linear.bias
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask2'
input: 'mask'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
__main__.AttentionBlock[aten_local_function](layer_norm, mask, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, mask2, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
Linear_2[aten_local_function](cat, decoder.attention.linear.weight) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'weight'
input: 'bias'
Constant(value=[-1, 16]) -> init7_s2_-1_16
Reshape(layer_norm_1, init7_s2_-1_16) -> MatMulAddPattern--layer_norm_1
Gemm(MatMulAddPattern--layer_norm_1, weight, bias, transB=1) -> MatMulAddPattern--layer_norm_12
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--layer_norm_12, init7_s3_1_30_-1) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=Linear_2_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'relu'
input: 'weight'
Constant(value=[-0.009325...) -> bias
Constant(value=[-1, 128]) -> init7_s2_-1_128
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, weight, bias, transB=1) -> MatMulAddPattern--relu2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> output
Constant(value=[-0.009325...) -> decoder.feed_forward.linear_2.bias
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.feed_forward.linear_1.bias'
Linear_3[aten_local_function](layer_norm_1, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> linear_1
ReLU[aten_local_function](linear_1) -> relu
Linear_2_2[aten_local_function](relu, decoder.feed_forward.linear_2.weight) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
input: 'mask2'
input: 'mask'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
input: 'decoder.feed_forward.linear_1.bias'
LayerNorm[aten_local_function](add) -> norm_1
__main__.MultiAttentionBlock[aten_local_function](norm_1, mask2, mask, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention
Add(attention, add) -> add_1
LayerNorm[aten_local_function](add_1) -> norm_2
__main__.FeedForward[aten_local_function](norm_2, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> feed_forward
Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?
We check again that there are no new discrepancies.
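The checking code is omitted here as well. A minimal sketch, assuming ExtendedReferenceEvaluator can execute a model containing local functions:

ref = ExtendedReferenceEvaluator(onx_module)
got = ref.run(None, {"input_ids": input_ids.numpy()})[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
# assumption: 'abs' holds the maximum absolute difference
print(f"max discrepancy={max_diff(y, got)['abs']}")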
output: shape=(1, 30, 16), min=-4.494992733001709, max=6.392674922943115
max discrepancy=2.384185791015625e-07
Let’s save the ONNX model.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually. (The rendered graph of the model with submodules is omitted from this extract.)
Inlining
The ONNX graph can still be inlined after this.
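The inlining call is not shown in this extract; a minimal sketch using onnx.inliner imported at the top:

# inline_local_functions returns a new ModelProto with local functions expanded
onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))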
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask)
init: name='decoder.attention.attention.0.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.query.weight)
init: name='decoder.attention.attention.0.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.key.weight)
init: name='decoder.attention.attention.0.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.0.value.weight)
init: name='mask2' type=float32 shape=(256, 256) -- GraphBuilder.make_local_function/from(mask2)
init: name='decoder.attention.attention.1.query.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.query.weight)
init: name='decoder.attention.attention.1.key.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.key.weight)
init: name='decoder.attention.attention.1.value.weight' type=float32 shape=(16, 16)-- GraphBuilder.make_local_function/from(decoder.attention.attention.1.value.weight)
init: name='decoder.attention.linear.weight' type=float32 shape=(16, 32)-- GraphBuilder.make_local_function/from(decoder.attention.linear.weight)
init: name='decoder.feed_forward.linear_1.weight' type=float32 shape=(128, 16)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.weight)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.weight' type=float32 shape=(16, 128)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_2.weight)
Constant(value=[1]) -> init7_s1_1__11
Gather(embedding.embedding.weight, input_ids) -> embedding__1
Gather(embedding.pe.weight, input_ids) -> pe__1
Add(embedding__1, pe__1) -> embedding
Constant(value=[1.0, 1.0,...) -> init1_s16___5
Constant(value=[0.0, 0.0,...) -> init1_s16_2__5
LayerNormalization(embedding, init1_s16___5, init1_s16_2__5, axis=-1, epsilon=0.00, stash_type=1) -> norm_1__4
Constant(value=0.25) -> init1_s___7
Constant(value=[1]) -> init7_s1_1__7
Reshape(init1_s___7, init7_s1_1__7) -> _reshape_init1_s_0__7
Constant(value=[0]) -> init7_s1_0__7
Concat(init7_s1_0__7, init7_s1_0__7, axis=0) -> SliceSlicePattern_init7_s1_0_start__7
Constant(value=[30]) -> init7_s1_30__7
Concat(init7_s1_30__7, init7_s1_30__7, axis=0) -> SliceSlicePattern_init7_s1_30_end__7
Constant(value=0.0) -> init1_s_2__7
Reshape(init1_s_2__7, init7_s1_1__7) -> _reshape_init1_s_20__7
Constant(value=[-inf]) -> init1_s1___7
Transpose(decoder.attention.attention.0.query.weight, perm=[1,0]) -> _onx_transpose_weight0__8
MatMul(norm_1__4, _onx_transpose_weight0__8) -> query__7
Transpose(decoder.attention.attention.0.key.weight, perm=[1,0]) -> _onx_transpose_weight0__9
MatMul(norm_1__4, _onx_transpose_weight0__9) -> key__7
Transpose(key__7, perm=[0,2,1]) -> transpose__7
MatMul(query__7, transpose__7) -> matmul__7
Mul(matmul__7, _reshape_init1_s_0__7) -> _onx_mul_matmul0__7
Transpose(decoder.attention.attention.0.value.weight, perm=[1,0]) -> _onx_transpose_weight0__10
MatMul(norm_1__4, _onx_transpose_weight0__10) -> value__7
Concat(init7_s1_0__7, init7_s1_1__7, axis=0) -> SliceSlicePattern_init7_s1_1_axis__7
Slice(mask, SliceSlicePattern_init7_s1_0_start__7, SliceSlicePattern_init7_s1_30_end__7, SliceSlicePattern_init7_s1_1_axis__7) -> slice_2__7
Equal(slice_2__7, _reshape_init1_s_20__7) -> eq__7_1
Where(eq__7_1, init1_s1___7, _onx_mul_matmul0__7) -> masked_fill__7
Softmax(masked_fill__7, axis=-1) -> softmax__7
MatMul(softmax__7, value__7) -> attention_0__6
Constant(value=0.25) -> init1_s___11
Reshape(init1_s___11, init7_s1_1__11) -> _reshape_init1_s_0__11
Constant(value=[0]) -> init7_s1_0__11
Concat(init7_s1_0__11, init7_s1_0__11, axis=0) -> SliceSlicePattern_init7_s1_0_start__11
Constant(value=[30]) -> init7_s1_30__11
Concat(init7_s1_30__11, init7_s1_30__11, axis=0) -> SliceSlicePattern_init7_s1_30_end__11
Constant(value=0.0) -> init1_s_2__11
Reshape(init1_s_2__11, init7_s1_1__11) -> _reshape_init1_s_20__11
Constant(value=[-inf]) -> init1_s1___11
Transpose(decoder.attention.attention.1.query.weight, perm=[1,0]) -> _onx_transpose_weight0__12
MatMul(norm_1__4, _onx_transpose_weight0__12) -> query__11
Transpose(decoder.attention.attention.1.key.weight, perm=[1,0]) -> _onx_transpose_weight0__13
MatMul(norm_1__4, _onx_transpose_weight0__13) -> key__11
Transpose(key__11, perm=[0,2,1]) -> transpose__11
MatMul(query__11, transpose__11) -> matmul__11
Mul(matmul__11, _reshape_init1_s_0__11) -> _onx_mul_matmul0__11
Transpose(decoder.attention.attention.1.value.weight, perm=[1,0]) -> _onx_transpose_weight0__14
MatMul(norm_1__4, _onx_transpose_weight0__14) -> value__11
Concat(init7_s1_0__11, init7_s1_1__11, axis=0) -> SliceSlicePattern_init7_s1_1_axis__11
Slice(mask2, SliceSlicePattern_init7_s1_0_start__11, SliceSlicePattern_init7_s1_30_end__11, SliceSlicePattern_init7_s1_1_axis__11) -> slice_2__11
Equal(slice_2__11, _reshape_init1_s_20__11) -> eq__11_2
Where(eq__11_2, init1_s1___11, _onx_mul_matmul0__11) -> masked_fill__11
Softmax(masked_fill__11, axis=-1) -> softmax__11
MatMul(softmax__11, value__11) -> attention_1__6
Concat(attention_0__6, attention_1__6, axis=-1) -> cat__6_0
Constant(value=[-0.081113...) -> bias__15
Constant(value=[-1, 32]) -> init7_s2_-1_32__15
Reshape(cat__6_0, init7_s2_-1_32__15) -> MatMulAddPattern--cat__15
Gemm(MatMulAddPattern--cat__15, decoder.attention.linear.weight, bias__15, transB=1) -> MatMulAddPattern--cat2__15
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1__15
Reshape(MatMulAddPattern--cat2__15, init7_s3_1_30_-1__15) -> attention__4
Add(attention__4, embedding) -> add_1__4
Constant(value=[-0.081113...) -> decoder.attention.linear.bias__15
Constant(value=[1.0, 1.0,...) -> init1_s16___16
Constant(value=[0.0, 0.0,...) -> init1_s16_2__16
LayerNormalization(add_1__4, init1_s16___16, init1_s16_2__16, axis=-1, epsilon=0.00, stash_type=1) -> norm_2__4
Constant(value=[-1, 16]) -> init7_s2_-1_16__18
Reshape(norm_2__4, init7_s2_-1_16__18) -> MatMulAddPattern--layer_norm_1__18
Gemm(MatMulAddPattern--layer_norm_1__18, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias, transB=1) -> MatMulAddPattern--layer_norm_12__18
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1__18
Reshape(MatMulAddPattern--layer_norm_12__18, init7_s3_1_30_-1__18) -> linear_1__17
Relu(linear_1__17) -> relu__17
Constant(value=[-0.009325...) -> bias__20
Constant(value=[-1, 128]) -> init7_s2_-1_128__20
Reshape(relu__17, init7_s2_-1_128__20) -> MatMulAddPattern--relu__20
Gemm(MatMulAddPattern--relu__20, decoder.feed_forward.linear_2.weight, bias__20, transB=1) -> MatMulAddPattern--relu2__20
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1__20
Reshape(MatMulAddPattern--relu2__20, init7_s3_1_30_-1__20) -> feed_forward__4
Add(feed_forward__4, add_1__4) -> output_0
Constant(value=[-0.009325...) -> decoder.feed_forward.linear_2.bias__20
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations
The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That's why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let's see how it goes.
onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2
    ),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-AUI.optimize] start with 73 nodes
[GraphBuilder-AUI.optimize] #patterns=63
[GraphBuilder-AUI.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 3:5/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 4:7/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 9:17/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 10:19/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 11:21/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-AUI.remove_unused] remove_initializer 12:23/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 13:25/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 14:27/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 15:29/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 16:31/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-AUI.remove_unused] remove_initializer 17:33/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-AUI.remove_unused] remove_initializer 18:35/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 1:4/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 2:5/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 3:6/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 4:7/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 5:8/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 6:9/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 7:10/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-AUI.remove_unused] remove_initializer 8:14/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 9:16/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-AUI.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-AUI.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-AUI.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-AUI.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-AUI.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-AUI.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-AUI.optimize] start with 53 nodes, 28 initializers, 63 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 1/63 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 2/63 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 3/63 - P0 - CastPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 4/63 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 5/63 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 6/63 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 7/63 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 8/63 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 9/63 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 10/63 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 11/63 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 12/63 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 13/63 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 14/63 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 15/63 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 16/63 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 17/63 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 18/63 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 19/63 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 20/63 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 21/63 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 22/63 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 23/63 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 24/63 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 25/63 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 26/63 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 27/63 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 28/63 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 29/63 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 30/63 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 31/63 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 32/63 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 33/63 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 34/63 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 35/63 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 36/63 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 37/63 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 38/63 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 39/63 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 40/63 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 41/63 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 42/63 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 43/63 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 44/63 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 45/63 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 46/63 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 47/63 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 48/63 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 49/63 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 50/63 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 51/63 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 52/63 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 53/63 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 54/63 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 55/63 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 56/63 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 57/63 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 58/63 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 59/63 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 60/63 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 61/63 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 62/63 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-AUI.optimize] use pattern 63/63 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-AUI.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-AUI.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.004 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilderPatternOptimization-AUI.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-AUI.optimize] increase priority to 1
[GraphBuilderPatternOptimization-AUI.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-AUI.optimize] applies 5 matches, 2*LayerNormalizationPattern, 3*MatMulAddPattern - time=0.003 | max_time=SimplifiedLayerNormalizationPattern:0.000
[GraphBuilderPatternOptimization-AUI.optimize] iteration 3: 38 nodes, priority=1
[GraphBuilderPatternOptimization-AUI.optimize] applies 5 matches, 3*GemmTransposePattern, 2*SkipLayerNormalizationPattern - time=0.002 | max_time=LeakyReluPattern:0.000
[GraphBuilderPatternOptimization-AUI.optimize] iteration 4: 39 nodes, priority=1
[GraphBuilderPatternOptimization-AUI.optimize] applies 1 matches, [0]=MatchResult: SwitchReshapeActivationPattern replaces ['Gemm', 'Reshape', 'Relu'] - time=0.005 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilderPatternOptimization-AUI.optimize] iteration 5: 39 nodes, priority=1
[GraphBuilderPatternOptimization-AUI.optimize] increase priority to 2
[GraphBuilderPatternOptimization-AUI.optimize] iteration 6: 39 nodes, priority=2
[GraphBuilderPatternOptimization-AUI.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=SoftmaxCrossEntropyLossCastPattern:0.000
[GraphBuilderPatternOptimization-AUI.optimize] iteration 7: 35 nodes, priority=2
[GraphBuilderPatternOptimization-AUI.optimize] increase priority to 3
[GraphBuilderPatternOptimization-AUI.optimize] iteration 8: 35 nodes, priority=3
[GraphBuilderPatternOptimization-AUI.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-AUI.optimize] done after 9 iterations with 35 nodes in 0.040
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00012172100105090067
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.000725833004253218
STAT apply_GemmTransposePattern +6 -3 #it=1 maxmatch=2 i=3 - time=0.00039411900070263073
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0001523140053905081
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.00041697999768075533
STAT apply_MatMulAddPattern +9 -6 #it=1 maxmatch=4 i=3 - time=0.0010484139966138173
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=4 i=2 - time=9.907500134431757e-05
STAT apply_SwitchReshapeActivationPattern +3 -3 #it=1 maxmatch=0 i=1 - time=0.0011121259994979482
STAT build_graph_for_pattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.002285908005433157
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=9.651100117480382e-05
STAT check_pattern_A0 +0 -0 #it=5 maxmatch=0 i=0 - time=0.0016396609971707221
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.0002223109986516647
STAT match_BatchNormalizationPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.00031334400409832597
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.00025375699988217093
STAT match_BiasGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00020232999304425903
STAT match_BiasSoftmaxPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.000182225998287322
STAT match_CastCastBinaryPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0004113289978704415
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00021253199884085916
STAT match_CastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0005489030045282561
STAT match_CastPattern +0 -0 #it=9 maxmatch=2 i=2 - time=0.0003119369976047892
STAT match_ClipClipPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00019729699852177873
STAT match_ComputationCastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0002948990004369989
STAT match_ConvBiasNullPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.0002426959945296403
STAT match_DropoutPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00016643099661450833
STAT match_ExpandBroadcastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00020741399566759355
STAT match_ExpandPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.0002338599988433998
STAT match_ExpandSwapPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0001757709978846833
STAT match_FastGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00019701799465110525
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00010753699825727381
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00020311600019340403
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004294680002203677
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.0002720299999054987
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00015074700058903545
STAT match_GeluErfPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0017460969938838389
STAT match_GeluOrtPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.002270511999086011
STAT match_GeluPattern +0 -0 #it=9 maxmatch=2 i=0 - time=6.5499953052494675e-06
STAT match_GemmTransposePattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.000497116991027724
STAT match_IdentityPattern +0 -0 #it=9 maxmatch=6 i=4 - time=0.0013756669977738056
STAT match_LayerNormalizationPattern +0 -0 #it=7 maxmatch=2 i=2 - time=0.00027545100238057785
STAT match_LayerNormalizationScalePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00018230999921797775
STAT match_LeakyReluPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.002426770002784906
STAT match_MatMulAddPattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.0007546460001321975
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.000803749993792735
STAT match_MulMulMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.000474083011795301
STAT match_MulMulMulScalarPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0002835019986378029
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003014980029547587
STAT match_QuickGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00030974100081948563
STAT match_ReduceReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0002395329975115601
STAT match_ReduceSumNormalizePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0001875000016298145
STAT match_Reshape2Of3Pattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00038779598980909213
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00036438400275073946
STAT match_ReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0005849339977430645
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00027456399766379036
STAT match_ReshapeReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0007279030032805167
STAT match_RotaryConcatPartPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00028065200240234844
STAT match_SameChildrenPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0007872379974287469
STAT match_SequenceConstructAtPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00023499500093748793
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00036171399915474467
STAT match_SkipLayerNormalizationPattern +0 -0 #it=7 maxmatch=5 i=2 - time=0.00024835799558786675
STAT match_SliceSlicePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00025403099789400585
STAT match_SlicesSplitPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003333430031489115
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.002861023003788432
STAT match_SoftmaxGradPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00022944000011193566
STAT match_SplitConcatPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00019994999820482917
STAT match_SqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00032750699028838426
STAT match_Sub1MulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002185280027333647
STAT match_SwitchOrderBinaryPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00034321300336159766
STAT match_SwitchReshapeActivationPattern +0 -0 #it=7 maxmatch=5 i=1 - time=0.00030888900073478
STAT match_TransposeEqualReshapePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00043849199937540106
STAT match_TransposeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0007170210010372102
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0004008799987786915
STAT match_TransposeReshapeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00036722300137626007
STAT match_TransposeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0003546459956851322
STAT match_UnsqueezeEqualPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003498419901006855
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00023904400586616248
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0004657120007323101
--MODEL: 35 nodes, 1 inputs, 1 outputs, 34 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 29 x 1t
INIT: 5 x 7t
NODE: 1 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 3 x Gemm
NODE: 8 x MatMul
NODE: 1 x Relu
NODE: 6 x Reshape
NODE: 2 x Softmax
NODE: 3 x Transpose
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 35 nodes, 1 inputs, 1 outputs, 34 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
INIT: 3 x 7t[2]
INIT: 1 x 7t[3]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x Gemm -SIG- 1t[30x128], 1t[16x128], 1t[16]
NODE: 1 x Gemm -SIG- 1t[30x16], 1t[128x16], 1t[128]
NODE: 1 x Gemm -SIG- 1t[30x32], 1t[16x32], 1t[16]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x Relu -SIG- 1t[30x128]
NODE: 1 x Reshape -SIG- 1t[1x30x128], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x16], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x32], 7t[2]
NODE: 1 x Reshape -SIG- 1t[30x128], 7t[3]
NODE: 2 x Reshape -SIG- 1t[30x16], 7t[3]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 1 x Transpose -SIG- 1t[128x16]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[16x128]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[32x16]-perm=1;0
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-AUI.remove_unused] remove_initializer 1:2/34:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 2:3/34:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 3:5/34:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 4:6/34:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 5:9/34:init7_s1_-1:int64[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 6:10/34:init1_s1_:float32[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 7:11/34:init1_s1_2:float32[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 8:16/34:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 9:22/34:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-AUI.remove_unused] remove_initializer 1:16/28:_onx_transpose_p_decoder_attention_linear_weight0:torch.float32[torch.Size([32, 16])]
[GraphBuilder-AUI.remove_unused] remove_initializer 2:17/28:_onx_transpose_p_decoder_feed_forward_linear_1_weight0:torch.float32[torch.Size([16, 128])]
[GraphBuilder-AUI.remove_unused] remove_initializer 3:18/28:_onx_transpose_p_decoder_feed_forward_linear_2_weight0:torch.float32[torch.Size([128, 16])]
[GraphBuilder-AUI.optimize] done with 32 nodes in 0.046
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32]) -- MatMulAddPattern.new_shape.1
init: name='init7_s3_1_30_-1' type=int64 shape=(3,) -- array([ 1, 30, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16]) -- MatMulAddPattern.new_shape.1
init: name='init7_s2_-1_128' type=int64 shape=(2,) -- array([ -1, 128])-- MatMulAddPattern.new_shape.1
init: name='GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(16, 32)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_attention_linear_weight0)##_onx_transpose_p_decoder_attention_linear_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_1_weight0)##_onx_transpose_p_decoder_feed_forward_linear_1_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_2_weight0)##_onx_transpose_p_decoder_feed_forward_linear_2_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add00, unused, unused2, add
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0, decoder.attention.linear.bias, transB=1) -> MatMulAddPattern--cat2
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_100, unused3, unused4, add_1
Reshape(_onx_div_sub_add_100, init7_s2_-1_16) -> MatMulAddPattern--_onx_div_sub_add_100
Gemm(MatMulAddPattern--_onx_div_sub_add_100, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0, decoder.feed_forward.linear_1.bias, transB=1) -> SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002
Relu(SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002) -> SwitchReshapeActivationPatternL_linear_7
Reshape(SwitchReshapeActivationPatternL_linear_7, init7_s3_1_30_-1) -> relu
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0, decoder.feed_forward.linear_2.bias, transB=1) -> MatMulAddPattern--relu2
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows the kernel FusedMatMul[com.microsoft], which implements a computation equivalent to Gemm but works for tensors of any rank, not only 2D ones.
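To make the equivalence concrete, here is a minimal sketch, not part of the original example, that builds a standalone FusedMatMul node and checks it against the plain formula alpha * A @ B^T on 3D inputs. The shapes and alpha=0.25 mirror the dump above; everything else is arbitrary.

import numpy as np
import onnx
import onnx.helper as oh
from onnxruntime import InferenceSession

# FusedMatMul computes alpha * A' @ B' where transA/transB transpose the last
# two dimensions, so it accepts batched (3D) inputs where Gemm requires 2D.
fused = oh.make_model(
    oh.make_graph(
        [
            oh.make_node(
                "FusedMatMul",
                ["A", "B"],
                ["Y"],
                alpha=0.25,
                transA=0,
                transB=1,
                domain="com.microsoft",
            )
        ],
        "fused_matmul_check",
        [
            oh.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 30, 16]),
            oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [1, 30, 16]),
        ],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 30, 30])],
    ),
    opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
    ir_version=9,
)
sess = InferenceSession(fused.SerializeToString(), providers=["CPUExecutionProvider"])
a = np.random.rand(1, 30, 16).astype(np.float32)
b = np.random.rand(1, 30, 16).astype(np.float32)
got = sess.run(None, {"A": a, "B": b})[0]
print(abs(got - 0.25 * a @ b.transpose(0, 2, 1)).max())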
How does it work on the model which keeps the modules exported as local functions? The optimizer optimizes every local function independently.
We reduce the verbosity…
onx_module_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16) -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.make_local_function/from(slice_4)
init: name='GemmTransposePattern--_onx_transpose_weight0' type=float32 shape=(16, 32)-- GraphBuilder.make_local_function/from(GemmTransposePattern--_onx_transpose_weight0)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='GemmTransposePattern--_onx_transpose_weight02' type=float32 shape=(128, 16)-- GraphBuilder.make_local_function/from(GemmTransposePattern--_onx_transpose_weight02)
init: name='GemmTransposePattern--_onx_transpose_weight022' type=float32 shape=(16, 128)-- GraphBuilder.make_local_function/from(GemmTransposePattern--_onx_transpose_weight022)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
__main__.DecoderLayer[aten_local_function](embedding, GemmTransposePattern--_onx_transpose_weight022, GemmTransposePattern--_onx_transpose_weight02, slice_4, slice_2, GemmTransposePattern--_onx_transpose_weight0, _onx_transpose_weight04, _onx_transpose_weight032, _onx_transpose_weight03, _onx_transpose_weight022, _onx_transpose_weight02, _onx_transpose_weight0, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
Constant(value=[1.0, 1.0,...) -> init1_s16_
Constant(value=[0.0, 0.0,...) -> init1_s16_2
LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: '_onx_transpose_weight0'
MatMul(layer_norm, _onx_transpose_weight0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_2'
input: '_onx_transpose_weight03'
input: '_onx_transpose_weight02'
input: '_onx_transpose_weight0'
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.0]) -> _reshape_init1_s_20
Equal(slice_2, _reshape_init1_s_20) -> eq
Linear[aten_local_function](layer_norm, _onx_transpose_weight0) -> query
Linear[aten_local_function](layer_norm, _onx_transpose_weight02) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
Linear[aten_local_function](layer_norm, _onx_transpose_weight03) -> value
MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'cat'
input: 'GemmTransposePattern--_onx_transpose_weight0'
Constant(value=[-0.081113...) -> bias
Constant(value=[-1, 32]) -> init7_s2_-1_32
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, GemmTransposePattern--_onx_transpose_weight0, bias, transB=1) -> MatMulAddPattern--cat2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> output
Constant(value=[-0.081113...) -> decoder.attention.linear.bias
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_4'
input: 'slice_2'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: '_onx_transpose_weight04'
input: '_onx_transpose_weight032'
input: '_onx_transpose_weight03'
input: '_onx_transpose_weight022'
input: '_onx_transpose_weight02'
input: '_onx_transpose_weight0'
__main__.AttentionBlock[aten_local_function](layer_norm, slice_2, _onx_transpose_weight03, _onx_transpose_weight02, _onx_transpose_weight0) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, slice_4, _onx_transpose_weight032, _onx_transpose_weight022, _onx_transpose_weight04) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
Linear_2[aten_local_function](cat, GemmTransposePattern--_onx_transpose_weight0) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: 'bias'
Constant(value=[-1, 16]) -> init7_s2_-1_16
Reshape(layer_norm_1, init7_s2_-1_16) -> MatMulAddPattern--layer_norm_1
Gemm(MatMulAddPattern--layer_norm_1, GemmTransposePattern--_onx_transpose_weight0, bias, transB=1) -> MatMulAddPattern--layer_norm_12
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--layer_norm_12, init7_s3_1_30_-1) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=Linear_2_2 domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'relu'
input: 'GemmTransposePattern--_onx_transpose_weight0'
Constant(value=[-0.009325...) -> bias
Constant(value=[-1, 128]) -> init7_s2_-1_128
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, GemmTransposePattern--_onx_transpose_weight0, bias, transB=1) -> MatMulAddPattern--relu2
Constant(value=[1, 30, -1...) -> init7_s3_1_30_-1
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> output
Constant(value=[-0.009325...) -> decoder.feed_forward.linear_2.bias
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: 'GemmTransposePattern--_onx_transpose_weight02'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: 'decoder.feed_forward.linear_1.bias'
Linear_3[aten_local_function](layer_norm_1, GemmTransposePattern--_onx_transpose_weight0, decoder.feed_forward.linear_1.bias) -> linear_1
ReLU[aten_local_function](linear_1) -> relu
Linear_2_2[aten_local_function](relu, GemmTransposePattern--_onx_transpose_weight02) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'add'
input: 'GemmTransposePattern--_onx_transpose_weight022'
input: 'GemmTransposePattern--_onx_transpose_weight02'
input: 'slice_4'
input: 'slice_2'
input: 'GemmTransposePattern--_onx_transpose_weight0'
input: '_onx_transpose_weight04'
input: '_onx_transpose_weight032'
input: '_onx_transpose_weight03'
input: '_onx_transpose_weight022'
input: '_onx_transpose_weight02'
input: '_onx_transpose_weight0'
input: 'decoder.feed_forward.linear_1.bias'
LayerNorm[aten_local_function](add) -> norm_1
__main__.MultiAttentionBlock[aten_local_function](norm_1, slice_4, slice_2, GemmTransposePattern--_onx_transpose_weight0, _onx_transpose_weight04, _onx_transpose_weight032, _onx_transpose_weight03, _onx_transpose_weight022, _onx_transpose_weight02, _onx_transpose_weight0) -> attention
Add(attention, add) -> add_1
LayerNorm[aten_local_function](add_1) -> norm_2
__main__.FeedForward[aten_local_function](norm_2, GemmTransposePattern--_onx_transpose_weight022, GemmTransposePattern--_onx_transpose_weight02, decoder.feed_forward.linear_1.bias) -> feed_forward
Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?
It seems to work as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
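If a runtime cannot execute models containing local functions, they can be inlined back into a flat graph. A minimal sketch, assuming the onnx inliner handles these custom-domain functions:

# Flatten the model by inlining every aten_local_function back into the
# main graph (assumption: inline_local_functions supports custom domains).
inlined = inline_local_functions(onx_module_optimized)
print(pretty_onnx(inlined))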
Optimizations for CUDA¶
The optimizer may behave differently knowing the model runs on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the right place to make such a decision, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more suitable for a given kernel. This could even be decided at runtime.
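As a hedged illustration of such a runtime-side decision, onnxruntime assigns every kernel to the first execution provider in the given list that implements it, falling back to the next one. model_bytes below is a placeholder for any serialized ONNX model, not a name from this example.

from onnxruntime import InferenceSession

# model_bytes: hypothetical placeholder for a serialized ONNX model.
sess = InferenceSession(
    model_bytes, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)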
onx_cuda_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(
patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
),
)
print(pretty_onnx(onx_cuda_optimized))
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
[GraphBuilder-VPW.optimize] start with 73 nodes
[GraphBuilder-VPW.optimize] #patterns=63
[GraphBuilder-VPW.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 3:5/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 4:7/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 9:17/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 10:19/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 11:21/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-VPW.remove_unused] remove_initializer 12:23/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 13:25/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 14:27/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 15:29/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 16:31/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-VPW.remove_unused] remove_initializer 17:33/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-VPW.remove_unused] remove_initializer 18:35/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 1:4/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 2:5/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 3:6/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 4:7/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 5:8/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 6:9/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 7:10/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-VPW.remove_unused] remove_initializer 8:14/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 9:16/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-VPW.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-VPW.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-VPW.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-VPW.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-VPW.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-VPW.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-VPW.optimize] start with 53 nodes, 28 initializers, 63 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 1/63 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 2/63 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 3/63 - P0 - CastPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 4/63 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 5/63 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 6/63 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 7/63 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 8/63 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 9/63 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 10/63 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 11/63 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 12/63 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 13/63 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 14/63 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 15/63 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 16/63 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 17/63 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 18/63 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 19/63 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 20/63 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 21/63 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 22/63 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 23/63 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 24/63 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 25/63 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 26/63 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 27/63 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 28/63 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 29/63 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 30/63 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 31/63 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 32/63 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 33/63 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 34/63 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 35/63 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 36/63 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 37/63 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 38/63 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 39/63 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 40/63 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 41/63 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 42/63 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 43/63 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 44/63 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 45/63 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 46/63 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 47/63 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 48/63 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 49/63 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 50/63 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 51/63 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 52/63 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 53/63 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 54/63 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 55/63 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 56/63 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 57/63 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 58/63 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 59/63 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 60/63 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 61/63 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 62/63 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-VPW.optimize] use pattern 63/63 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-VPW.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-VPW.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.005 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilderPatternOptimization-VPW.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-VPW.optimize] increase priority to 1
[GraphBuilderPatternOptimization-VPW.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-VPW.optimize] applies 5 matches, 2*LayerNormalizationPattern, 3*MatMulAddPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-VPW.optimize] iteration 3: 38 nodes, priority=1
[GraphBuilderPatternOptimization-VPW.optimize] applies 5 matches, 3*GemmTransposePattern, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=LeakyReluPattern:0.000
[GraphBuilderPatternOptimization-VPW.optimize] iteration 4: 39 nodes, priority=1
[GraphBuilderPatternOptimization-VPW.optimize] applies 1 matches, [0]=MatchResult: SwitchReshapeActivationPattern replaces ['Gemm', 'Reshape', 'Relu'] - time=0.005 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilderPatternOptimization-VPW.optimize] iteration 5: 39 nodes, priority=1
[GraphBuilderPatternOptimization-VPW.optimize] increase priority to 2
[GraphBuilderPatternOptimization-VPW.optimize] iteration 6: 39 nodes, priority=2
[GraphBuilderPatternOptimization-VPW.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=FusedMatMulPattern:0.000
[GraphBuilderPatternOptimization-VPW.optimize] iteration 7: 35 nodes, priority=2
[GraphBuilderPatternOptimization-VPW.optimize] increase priority to 3
[GraphBuilderPatternOptimization-VPW.optimize] iteration 8: 35 nodes, priority=3
[GraphBuilderPatternOptimization-VPW.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-VPW.optimize] done after 9 iterations with 35 nodes in 0.057
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00021883299996261485
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0007877910029492341
STAT apply_GemmTransposePattern +6 -3 #it=1 maxmatch=2 i=3 - time=0.0010516260044823866
STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0002963299957627896
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.00044749800144927576
STAT apply_MatMulAddPattern +9 -6 #it=1 maxmatch=4 i=3 - time=0.002857441002561245
STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=4 i=2 - time=0.00021683299928554334
STAT apply_SwitchReshapeActivationPattern +3 -3 #it=1 maxmatch=0 i=1 - time=0.0011760509987652767
STAT build_graph_for_pattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.003212256993720075
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011221899694646709
STAT check_pattern_A0 +0 -0 #it=5 maxmatch=0 i=0 - time=0.003005019992997404
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.00043266099964967
STAT match_BatchNormalizationPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.0005189069925108925
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=9 maxmatch=0 i=0 - time=0.00036961500154575333
STAT match_BiasGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00040071499461191706
STAT match_BiasSoftmaxPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0006103839914430864
STAT match_CastCastBinaryPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0006157509997137822
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0003368849938851781
STAT match_CastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0006655889992543962
STAT match_CastPattern +0 -0 #it=9 maxmatch=2 i=2 - time=0.0003831169960903935
STAT match_ClipClipPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00030735699692741036
STAT match_ComputationCastOpCastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0004056609977851622
STAT match_ConvBiasNullPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.0003919849987141788
STAT match_DropoutPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.00025196000206051394
STAT match_ExpandBroadcastPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0003357439964020159
STAT match_ExpandPattern +0 -0 #it=9 maxmatch=2 i=0 - time=0.0003591550012060907
STAT match_ExpandSwapPattern +0 -0 #it=7 maxmatch=0 i=0 - time=0.0003357900059199892
STAT match_FastGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00030285000320873223
STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00019926100503653288
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00045169200166128576
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0008860710040607955
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00025623900000937283
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00040981400161399506
STAT match_GeluErfPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.002260344994283514
STAT match_GeluOrtPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.002511578997655306
STAT match_GeluPattern +0 -0 #it=9 maxmatch=2 i=0 - time=1.3376000424614176e-05
STAT match_GemmTransposePattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.00039882200144347735
STAT match_IdentityPattern +0 -0 #it=9 maxmatch=6 i=4 - time=0.0021071279952593613
STAT match_LayerNormalizationPattern +0 -0 #it=7 maxmatch=2 i=2 - time=0.00040740500116953626
STAT match_LayerNormalizationScalePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00026684199838200584
STAT match_LeakyReluPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.003287355000793468
STAT match_MatMulAddPattern +0 -0 #it=7 maxmatch=5 i=3 - time=0.0009314250091847498
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0007469700067304075
STAT match_MulMulMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0005072530038887635
STAT match_MulMulMulScalarPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0003926520003005862
STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0006269189943850506
STAT match_QuickGeluPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00041890099601005204
STAT match_ReduceReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00026705100026447326
STAT match_ReduceSumNormalizePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00019329399219714105
STAT match_Reshape2Of3Pattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00043636800546664745
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.0003798319958150387
STAT match_ReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.000810945999546675
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=7 maxmatch=2 i=0 - time=0.00029445800464600325
STAT match_ReshapeReshapePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0006977580014790874
STAT match_RotaryConcatPartPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002928560024884064
STAT match_SameChildrenPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0009037410018208902
STAT match_SequenceConstructAtPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00027082100132247433
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00048583199895801954
STAT match_SkipLayerNormalizationPattern +0 -0 #it=7 maxmatch=5 i=2 - time=0.0004164519996265881
STAT match_SliceSlicePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00022787000125390477
STAT match_SlicesSplitPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003094139974564314
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0029770590081170667
STAT match_SoftmaxGradPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00030457100001513027
STAT match_SplitConcatPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0002754680026555434
STAT match_SqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0004329880030127242
STAT match_Sub1MulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0003434079990256578
STAT match_SwitchOrderBinaryPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00047474099483224563
STAT match_SwitchReshapeActivationPattern +0 -0 #it=7 maxmatch=5 i=1 - time=0.0005965789969195612
STAT match_TransposeEqualReshapePattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0006059559964342043
STAT match_TransposeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.00098943299963139
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.000566988001082791
STAT match_TransposeReshapeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0011015459967893548
STAT match_TransposeTransposePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.0005388229983509518
STAT match_UnsqueezeEqualPattern +0 -0 #it=7 maxmatch=5 i=0 - time=0.0005577259980782401
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=9 maxmatch=6 i=0 - time=0.00040844199611456133
STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0009799229992495384
--MODEL: 35 nodes, 1 inputs, 1 outputs, 34 initializers--
INPUT: 1 x 7t
INPUT-SEQ: 1 x Falset
OUTPUT: 1 x 1t
OUTPUT-SEQ: 1 x Falset
INIT: 29 x 1t
INIT: 5 x 7t
NODE: 1 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 3 x Gemm
NODE: 8 x MatMul
NODE: 1 x Relu
NODE: 6 x Reshape
NODE: 2 x Softmax
NODE: 3 x Transpose
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
NODE: 2 x com.microsoft.SkipLayerNormalization
--MODEL: 35 nodes, 1 inputs, 1 outputs, 34 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
INIT: 3 x 7t[2]
INIT: 1 x 7t[3]
NODE: 1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 1 x Gemm -SIG- 1t[30x128], 1t[16x128], 1t[16]
NODE: 1 x Gemm -SIG- 1t[30x16], 1t[128x16], 1t[128]
NODE: 1 x Gemm -SIG- 1t[30x32], 1t[16x32], 1t[16]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x Relu -SIG- 1t[30x128]
NODE: 1 x Reshape -SIG- 1t[1x30x128], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x16], 7t[2]
NODE: 1 x Reshape -SIG- 1t[1x30x32], 7t[2]
NODE: 1 x Reshape -SIG- 1t[30x128], 7t[3]
NODE: 2 x Reshape -SIG- 1t[30x16], 7t[3]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 1 x Transpose -SIG- 1t[128x16]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[16x128]-perm=1;0
NODE: 1 x Transpose -SIG- 1t[32x16]-perm=1;0
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-VPW.remove_unused] remove_initializer 1:2/34:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 2:3/34:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 3:5/34:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 4:6/34:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 5:9/34:init7_s1_-1:int64[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 6:10/34:init1_s1_:float32[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 7:11/34:init1_s1_2:float32[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 8:16/34:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 9:22/34:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-VPW.remove_unused] remove_initializer 1:16/28:_onx_transpose_p_decoder_attention_linear_weight0:torch.float32[torch.Size([32, 16])]
[GraphBuilder-VPW.remove_unused] remove_initializer 2:17/28:_onx_transpose_p_decoder_feed_forward_linear_1_weight0:torch.float32[torch.Size([16, 128])]
[GraphBuilder-VPW.remove_unused] remove_initializer 3:18/28:_onx_transpose_p_decoder_feed_forward_linear_2_weight0:torch.float32[torch.Size([128, 16])]
[GraphBuilder-VPW.optimize] done with 32 nodes in 0.067
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30) -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s16_' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32]) -- MatMulAddPattern.new_shape.1
init: name='init7_s3_1_30_-1' type=int64 shape=(3,) -- array([ 1, 30, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_-1_16' type=int64 shape=(2,) -- array([-1, 16]) -- MatMulAddPattern.new_shape.1
init: name='init7_s2_-1_128' type=int64 shape=(2,) -- array([ -1, 128])-- MatMulAddPattern.new_shape.1
init: name='GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(16, 32)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_attention_linear_weight0)##_onx_transpose_p_decoder_attention_linear_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_1_weight0)##_onx_transpose_p_decoder_feed_forward_linear_1_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(_onx_transpose_p_decoder_feed_forward_linear_2_weight0)##_onx_transpose_p_decoder_feed_forward_linear_2_weight0/GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,) -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add00, unused, unused2, add
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
Softmax(masked_fill_1, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
Reshape(cat, init7_s2_-1_32) -> MatMulAddPattern--cat
Gemm(MatMulAddPattern--cat, GemmTransposePattern--_onx_transpose_p_decoder_attention_linear_weight0, decoder.attention.linear.bias, transB=1) -> MatMulAddPattern--cat2
Reshape(MatMulAddPattern--cat2, init7_s3_1_30_-1) -> linear_6
SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_100, unused3, unused4, add_1
Reshape(_onx_div_sub_add_100, init7_s2_-1_16) -> MatMulAddPattern--_onx_div_sub_add_100
Gemm(MatMulAddPattern--_onx_div_sub_add_100, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_1_weight0, decoder.feed_forward.linear_1.bias, transB=1) -> SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002
Relu(SwitchReshapeActivationPatternL_MatMulAddPattern--_onx_div_sub_add_1002) -> SwitchReshapeActivationPatternL_linear_7
Reshape(SwitchReshapeActivationPatternL_linear_7, init7_s3_1_30_-1) -> relu
Reshape(relu, init7_s2_-1_128) -> MatMulAddPattern--relu
Gemm(MatMulAddPattern--relu, GemmTransposePattern--_onx_transpose_p_decoder_feed_forward_linear_2_weight0, decoder.feed_forward.linear_2.bias, transB=1) -> MatMulAddPattern--relu2
Reshape(MatMulAddPattern--relu2, init7_s3_1_30_-1) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
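The optimized graph relies on operators from the com.microsoft domain (FusedMatMul, SkipLayerNormalization), so it can only be executed by onnxruntime. The following sketch is not part of the original script; it shows one way to check that the two models still agree numerically, assuming onx, onx_optimized and the example input input_ids are still in scope.

# Sketch: compare both models with onnxruntime, assuming onx,
# onx_optimized and input_ids (int64 tensor of shape 1x30) exist.
feeds = {"input_ids": input_ids.numpy()}
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
sess_opt = InferenceSession(
    onx_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)
expected = sess.run(None, feeds)[0]
got = sess_opt.run(None, feeds)[0]
print(max_diff(expected, got))  # the discrepancy should be close to zero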
Comparing the optimized and not optimized models¶
The following tool tries to match nodes and shape information between two models. If the models are not too different, the function can align their execution traces and report the differences. We can use it to see which operators were fused into bigger ones that are only implemented by onnxruntime.
res1, res2, align, dc = compare_onnx_execution(
onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 86 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 86 results (first model)
[compare_onnx_execution] got 65 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 88 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32 2:256x256 AOCQ b_ | INITIA float32 1:1 ?AAA in
002 - | INITIA float32 2:256x256 AOCQ b_ |
003 ~ | INITIA float32 AAAA in | INITIA float32 2:16x16 AAZA _o
004 ~ | INITIA int64 1:1 BAAA in | INITIA float32 2:16x16 ADAA _o
005 ~ | INITIA int64 1:1 AAAA in | INITIA float32 2:16x16 ACAA _o
006 ~ | INITIA int64 1:1 EAAA in | INITIA float32 2:30x30 KGSP sl
007 ~ | INITIA float32 AAAA in | INITIA float32 1:1 AAAA _r
008 ~ | INITIA float32 1:1 ?AAA in | INITIA float32 2:16x16 AAAA _o
009 ~ | INITIA float32 1:16 EEEE in | INITIA float32 2:16x16 AAAA _o
010 ~ | INITIA float32 1:16 AAAA in | INITIA float32 2:16x16 AAAZ _o
011 ~ | INITIA int64 1:2 ZGAA in | INITIA float32 2:30x30 KGSP sl
012 ~ | INITIA int64 1:3 BEZA in | INITIA float32 1:1 AAAA _r
013 ~ | INITIA int64 1:2 ZQAA in | INITIA float32 1:16 EEEE in
014 ~ | INITIA int64 1:2 ZYAA in | INITIA float32 1:16 AAAA in
015 - | INITIA float32 2:1024x16 DIXO em |
016 - | INITIA float32 2:1024x16 QKQF em |
017 ~ | INITIA float32 2:16x16 ZYAA de | INITIA int64 1:2 ZGAA in
018 ~ | INITIA float32 2:16x16 AAAC de | INITIA int64 1:3 BEZA in
019 ~ | INITIA float32 2:16x16 BAZB de | INITIA int64 1:2 ZQAA in
020 ~ | INITIA float32 2:16x16 AZBA de | INITIA int64 1:2 ZYAA in
021 ~ | INITIA float32 2:16x16 ABAY de | INITIA float32 2:16x32 ABAA Ge
022 ~ | INITIA float32 2:16x16 ZBZA de | INITIA float32 2:128x16 YAXY Ge
023 ~ | INITIA float32 2:16x32 ABAA de | INITIA float32 2:16x128 AAAA Ge
024 + | | INITIA float32 2:1024x16 DIXO em
025 + | | INITIA float32 2:1024x16 QKQF em
026 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
027 - | INITIA float32 2:128x16 YAXY de |
028 = | INITIA float32 1:128 AYBB de | INITIA float32 1:128 AYBB de
029 - | INITIA float32 2:16x128 AAAA de |
030 = | INITIA float32 1:16 AAAA de | INITIA float32 1:16 AAAA de
031 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
032 = | RESULT float32 3:1x30x16 CRMY Gather em | RESULT float32 3:1x30x16 CRMY Gather em
033 = | RESULT float32 3:1x30x16 FYGL Gather em | RESULT float32 3:1x30x16 FYGL Gather em
034 - | RESULT float32 3:1x30x16 IOSJ Add ad |
035 ~ | RESULT float32 3:1x30x16 AAAA LayerNormalizat _o | RESULT float32 3:1x30x16 AAAA SkipLayerNormal _o
036 ~ | RESULT float32 2:16x16 AAZA Transpose _o | RESULT float32 3:1x30x1 AABZ SkipLayerNormal un
037 - | RESULT float32 3:1x30x16 XOHA MatMul li |
038 ~ | RESULT float32 2:16x16 ADAA Transpose _o | RESULT float32 3:1x30x1 FGGE SkipLayerNormal un
039 ~ | RESULT float32 3:1x30x16 AGDJ MatMul li | RESULT float32 3:1x30x16 IOSJ SkipLayerNormal ad
040 - | RESULT float32 2:16x16 ACAA Transpose _o |
041 ~ | RESULT float32 3:1x30x16 WAZB MatMul li | RESULT float32 3:1x30x16 XOHA MatMul li
042 ~ | RESULT float32 3:1x16x30 DBJF Transpose tr | RESULT float32 3:1x30x16 AGDJ MatMul li
043 ~ | RESULT float32 3:1x30x30 FRFO MatMul ma | RESULT float32 3:1x30x30 VSBX FusedMatMul _o
044 - | RESULT float32 1:1 AAAA Reshape _r |
045 ~ | RESULT float32 3:1x30x30 VSBX Mul _o | RESULT float32 3:1x30x16 WAZB MatMul li
046 - | RESULT int64 1:2 AAAA Concat Sl |
047 - | RESULT int64 1:2 EEAA Concat Sl |
048 - | RESULT int64 1:2 ABAA Concat Sl |
049 - | RESULT float32 2:30x30 KGSP Slice sl |
050 - | RESULT float32 1:1 AAAA Reshape _r |
051 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
052 = | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x30 ???? Where ma
053 = | RESULT float32 3:1x30x30 HHHH Softmax so | RESULT float32 3:1x30x30 HHHH Softmax so
054 = | RESULT float32 3:1x30x16 XXZA MatMul ma | RESULT float32 3:1x30x16 XXZA MatMul ma
055 - | RESULT float32 2:16x16 AAAA Transpose _o |
056 = | RESULT float32 3:1x30x16 CBCB MatMul li | RESULT float32 3:1x30x16 CBCB MatMul li
057 - | RESULT float32 2:16x16 AAAA Transpose _o |
058 = | RESULT float32 3:1x30x16 ACZD MatMul li | RESULT float32 3:1x30x16 ACZD MatMul li
059 - | RESULT float32 2:16x16 AAAZ Transpose _o |
060 ~ | RESULT float32 3:1x30x16 AFTV MatMul li | RESULT float32 3:1x30x30 BAAF FusedMatMul _o
061 ~ | RESULT float32 3:1x16x30 XEZG Transpose tr | RESULT float32 3:1x30x16 AFTV MatMul li
062 ~ | RESULT float32 3:1x30x30 GAAX MatMul ma | RESULT bool 2:30x30 HLZC Equal eq
063 - | RESULT float32 1:1 AAAA Reshape _r |
064 ~ | RESULT float32 3:1x30x30 BAAF Mul _o | RESULT float32 3:1x30x30 ???? Where ma
065 - | RESULT int64 1:2 AAAA Concat Sl |
066 - | RESULT int64 1:2 EEAA Concat Sl |
067 - | RESULT int64 1:2 ABAA Concat Sl |
068 ~ | RESULT float32 2:30x30 KGSP Slice sl | RESULT float32 3:1x30x30 HGHH Softmax so
069 - | RESULT float32 1:1 AAAA Reshape _r |
070 ~ | RESULT bool 2:30x30 HLZC Equal eq | RESULT float32 3:1x30x16 ZAAZ MatMul ma
071 ~ | RESULT float32 3:1x30x30 ???? Where ma | RESULT float32 3:1x30x32 VYZZ Concat ca
072 ~ | RESULT float32 3:1x30x30 HGHH Softmax so | RESULT float32 2:30x32 VYZZ Reshape Ma
073 ~ | RESULT float32 3:1x30x16 ZAAZ MatMul ma | RESULT float32 2:30x16 VAZA Gemm Ma
074 ~ | RESULT float32 3:1x30x32 VYZZ Concat ca | RESULT float32 3:1x30x16 VAZA Reshape li
075 - | RESULT float32 2:30x32 VYZZ Reshape Ma |
076 ~ | RESULT float32 2:30x16 VAZA Gemm Ma | RESULT float32 3:1x30x16 ZBAA SkipLayerNormal _o
077 ~ | RESULT float32 3:1x30x16 VAZA Reshape li | RESULT float32 3:1x30x1 AABZ SkipLayerNormal un
078 ~ | RESULT float32 3:1x30x16 CORJ Add ad | RESULT float32 3:1x30x1 FGGE SkipLayerNormal un
079 ~ | RESULT float32 3:1x30x16 ZBAA LayerNormalizat _o | RESULT float32 3:1x30x16 CORJ SkipLayerNormal ad
080 = | RESULT float32 2:30x16 ZBAA Reshape Ma | RESULT float32 2:30x16 ZBAA Reshape Ma
081 = | RESULT float32 2:30x128 BASX Gemm Sw | RESULT float32 2:30x128 BASX Gemm Sw
082 = | RESULT float32 2:30x128 ASMS Relu Sw | RESULT float32 2:30x128 ASMS Relu Sw
083 = | RESULT float32 3:1x30x128 ASMS Reshape re | RESULT float32 3:1x30x128 ASMS Reshape re
084 = | RESULT float32 2:30x128 ASMS Reshape Ma | RESULT float32 2:30x128 ASMS Reshape Ma
085 = | RESULT float32 2:30x16 BAAA Gemm Ma | RESULT float32 2:30x16 BAAA Gemm Ma
086 = | RESULT float32 3:1x30x16 BAAA Reshape li | RESULT float32 3:1x30x16 BAAA Reshape li
087 = | RESULT float32 3:1x30x16 ENQI Add ou | RESULT float32 3:1x30x16 ENQI Add ou
088 = | OUTPUT float32 3:1x30x16 ENQI ou | OUTPUT float32 3:1x30x16 ENQI ou
The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that's a topic for another example.
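As a rough sketch of what that could look like, assuming to_onnx forwards a dynamic_shapes argument in the style of torch.export.export, and that the model instance llm and the input input_ids are the ones defined earlier in this example:

# Sketch only, not executed here: mark the sequence axis as dynamic.
# "seq" is an arbitrary axis name; the upper bound matches context_size.
seq = torch.export.Dim("seq", min=1, max=256)
onx_dyn = to_onnx(llm, (input_ids,), dynamic_shapes=({1: seq},))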
Total running time of the script: (0 minutes 6.516 seconds)
Related examples
to_onnx and padding one dimension to a multiple of a constant
torch.onnx.export and padding one dimension to a multiple of a constant