to_onnx and submodules from LLMs¶
Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.
A simple LLM¶
All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.
import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
import torch
from onnxruntime import InferenceSession
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.bench_run import max_diff
from experimental_experiment.xbuilder import OptimizationOptions
class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        B, T, C = x.size()
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)
        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x
        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention
        return ff


class LLM(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-3.9085302352905273, max=3.940504550933838
First conversion to ONNX¶
The conversion relies on torch.export.export(), which gives:
ep = torch.export.export(llm, (input_ids,))
print(ep.graph)
# Then function :func:`to_onnx <experimental_experiment.torch_interpreter.to_onnx>`
# converts it into ONNX.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
graph():
%p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
%p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
%p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
%p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
%p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
%p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
%p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
%p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
%p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
%p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
%p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
%p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
%p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
%p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
%p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
%p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
%p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
%p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
%b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
%b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
%input_ids : [num_users=2] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
%embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
%add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
%layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
%linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
%transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
%matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
%eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
%masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
%softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
%matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
%linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
%linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
%linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
%transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
%matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
%slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
%masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
%softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
%matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
%linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
%add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
%layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
%linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
%linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
return (add_2,)
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='p_embedding_embedding_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_embedding_pe_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_decoder_norm_1_weight' type=dtype('float32') shape=(16,)
init: name='p_decoder_norm_1_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_attention_attention_0_query_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_0_key_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_0_value_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_1_query_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_1_key_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_1_value_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_linear_weight' type=dtype('float32') shape=(16, 32)
init: name='p_decoder_attention_linear_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_norm_2_weight' type=dtype('float32') shape=(16,)
init: name='p_decoder_norm_2_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_feed_forward_linear_1_weight' type=dtype('float32') shape=(128, 16)
init: name='p_decoder_feed_forward_linear_1_bias' type=dtype('float32') shape=(128,)
init: name='p_decoder_feed_forward_linear_2_weight' type=dtype('float32') shape=(16, 128)
init: name='p_decoder_feed_forward_linear_2_bias' type=dtype('float32') shape=(16,)
init: name='b_decoder_attention_attention_0_mask' type=dtype('float32') shape=(256, 256)
init: name='b_decoder_attention_attention_1_mask' type=dtype('float32') shape=(256, 256)
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.25], dtype=float32)
init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
init: name='init7_s1_0' type=dtype('int64') shape=(1,) -- array([0])
init: name='init7_s1_30' type=dtype('int64') shape=(1,) -- array([30])
init: name='init1_s_2' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-inf], dtype=float32)
init: name='init1_s16_' type=dtype('float32') shape=(16,)
init: name='init1_s16_2' type=dtype('float32') shape=(16,)
Gather(p_embedding_embedding_weight, input_ids) -> embedding
Gather(p_embedding_pe_weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
Mul(init1_s16_, p_decoder_norm_1_weight) -> LayerNormalizationScalePattern_init1_s16_
Mul(p_decoder_norm_1_weight, init1_s16_2) -> LayerNormalizationScalePattern_init1_s16_2
Add(LayerNormalizationScalePattern_init1_s16_2, p_decoder_norm_1_bias) -> LayerNormalizationScalePattern_init1_s16_3
LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add02
Transpose(p_decoder_attention_attention_0_query_weight, perm=[1,0]) -> _onx_transpose0
MatMul(_onx_add02, _onx_transpose0) -> linear
Transpose(p_decoder_attention_attention_0_key_weight, perm=[1,0]) -> _onx_transpose02
MatMul(_onx_add02, _onx_transpose02) -> linear_1
Transpose(linear_1, perm=[0,2,1]) -> transpose
MatMul(linear, transpose) -> matmul
Transpose(p_decoder_attention_attention_0_value_weight, perm=[1,0]) -> _onx_transpose03
MatMul(_onx_add02, _onx_transpose03) -> linear_2
Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
Mul(matmul, _onx_reshape0) -> _onx_mul02
Slice(b_decoder_attention_attention_0_mask, init7_s1_0, init7_s1_30, init7_s1_0) -> slice_1
Slice(slice_1, init7_s1_0, init7_s1_30, init7_s1_1) -> slice_2
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape02
Equal(slice_2, _onx_reshape02) -> eq
Where(eq, init1_s1_3, _onx_mul02) -> _onx_where0
Softmax(_onx_where0, axis=-1) -> softmax
MatMul(softmax, linear_2) -> matmul_1
Transpose(p_decoder_attention_attention_1_query_weight, perm=[1,0]) -> _onx_transpose04
MatMul(_onx_add02, _onx_transpose04) -> linear_3
Transpose(p_decoder_attention_attention_1_key_weight, perm=[1,0]) -> _onx_transpose05
MatMul(_onx_add02, _onx_transpose05) -> linear_4
Transpose(linear_4, perm=[0,2,1]) -> transpose_1
MatMul(linear_3, transpose_1) -> matmul_2
Transpose(p_decoder_attention_attention_1_value_weight, perm=[1,0]) -> _onx_transpose06
MatMul(_onx_add02, _onx_transpose06) -> linear_5
Reshape(init1_s_, init7_s1_1) -> _onx_reshape03
Mul(matmul_2, _onx_reshape03) -> _onx_mul03
Slice(b_decoder_attention_attention_1_mask, init7_s1_0, init7_s1_30, init7_s1_0) -> slice_3
Slice(slice_3, init7_s1_0, init7_s1_30, init7_s1_1) -> slice_4
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape04
Equal(slice_4, _onx_reshape04) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul03) -> _onx_where02
Softmax(_onx_where02, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
Transpose(p_decoder_attention_linear_weight, perm=[1,0]) -> _onx_transpose07
MatMul(cat, _onx_transpose07) -> _onx_matmul0
Add(_onx_matmul0, p_decoder_attention_linear_bias) -> linear_6
Add(linear_6, add) -> add_1
Mul(init1_s16_, p_decoder_norm_2_weight) -> LayerNormalizationScalePattern_init1_s16_4
Mul(p_decoder_norm_2_weight, init1_s16_2) -> LayerNormalizationScalePattern_init1_s16_5
Add(LayerNormalizationScalePattern_init1_s16_5, p_decoder_norm_2_bias) -> LayerNormalizationScalePattern_init1_s16_6
LayerNormalization(add_1, LayerNormalizationScalePattern_init1_s16_4, LayerNormalizationScalePattern_init1_s16_6, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add04
Transpose(p_decoder_feed_forward_linear_1_weight, perm=[1,0]) -> _onx_transpose08
MatMul(_onx_add04, _onx_transpose08) -> _onx_matmul02
Add(_onx_matmul02, p_decoder_feed_forward_linear_1_bias) -> linear_7
Relu(linear_7) -> relu
Transpose(p_decoder_feed_forward_linear_2_weight, perm=[1,0]) -> _onx_transpose09
MatMul(relu, _onx_transpose09) -> _onx_matmul03
Add(_onx_matmul03, p_decoder_feed_forward_linear_2_bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Let’s check there is no discrepancy.
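A minimal sketch of the check, using the InferenceSession and max_diff imports above (assuming max_diff returns a dictionary with an 'abs' key):

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={max_diff(y, got)['abs']}")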
output: shape=(1, 30, 16), min=-3.9085307121276855, max=3.940504312515259
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
ONNX with submodules¶
Let’s produce an ONNX model with submodules.
Function to_onnx calls torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.
ep = torch.export.export(llm, (input_ids,))
unflatten_ep = torch.export.unflatten(ep)
print(unflatten_ep.graph)
graph():
%input_ids : [num_users=1] = placeholder[target=input_ids]
%embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
return (decoder,)
The exported graph looks simpler and shows something like:
%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
It preserves the hierarchy but does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule, taken from the previous graph.
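A quick way to verify this; a minimal sketch using get_submodule, the standard torch.nn.Module accessor:

# expected to print an InterpreterModule, not the original Embedding class
print(type(unflatten_ep.get_submodule("embedding")))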
Now the ONNX graph.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=dtype('float32') shape=(1024, 16)
init: name='embedding.pe.weight' type=dtype('float32') shape=(1024, 16)
init: name='mask' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.0.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.value.weight' type=dtype('float32') shape=(16, 16)
init: name='mask2' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.1.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.value.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.linear.weight' type=dtype('float32') shape=(16, 32)
init: name='decoder.feed_forward.linear_1.weight' type=dtype('float32') shape=(128, 16)
init: name='decoder.feed_forward.linear_1.bias' type=dtype('float32') shape=(128,)
init: name='decoder.feed_forward.linear_2.weight' type=dtype('float32') shape=(16, 128)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
__main__.DecoderLayer[aten_local_function](embedding, mask2, mask, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
input: 'weight'
input: 'bias'
Constant(value=[1.0, 1.0,...) -> init1_s16_
Mul(init1_s16_, weight) -> LayerNormalizationScalePattern_init1_s16_
Constant(value=[0.0, 0.0,...) -> init1_s16_2
Mul(weight, init1_s16_2) -> LayerNormalizationScalePattern_init1_s16_2
Add(LayerNormalizationScalePattern_init1_s16_2, bias) -> LayerNormalizationScalePattern_init1_s16_3
LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'weight'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
MatMul(layer_norm, _onx_transpose0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
Constant(value=0.25) -> init1_s_
Constant(value=[1]) -> init7_s1_1
Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
Constant(value=[0]) -> init7_s1_0
Constant(value=[30]) -> init7_s1_30
Slice(mask, init7_s1_0, init7_s1_30, init7_s1_0) -> slice_1
Slice(slice_1, init7_s1_0, init7_s1_30, init7_s1_1) -> slice_2
Constant(value=0.0) -> init1_s_2
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape02
Equal(slice_2, _onx_reshape02) -> eq
Constant(value=[-inf]) -> init1_s1_
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.query.weight) -> query
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.key.weight) -> key
Transpose(key, perm=[0,2,1]) -> transpose
MatMul(query, transpose) -> matmul
Mul(matmul, _onx_reshape0) -> _onx_mul0
Where(eq, init1_s1_, _onx_mul0) -> _onx_where0
Softmax(_onx_where0, axis=-1) -> softmax
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.value.weight) -> value
MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'cat'
input: 'weight'
input: 'bias'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
MatMul(cat, _onx_transpose0) -> _onx_matmul0
Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask2'
input: 'mask'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
Constant(value=[-0.150363...) -> decoder.attention.linear.bias
__main__.AttentionBlock[aten_local_function](layer_norm, mask, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, mask2, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
Linear_2[aten_local_function](cat, decoder.attention.linear.weight, decoder.attention.linear.bias) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'weight'
input: 'bias'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
MatMul(layer_norm_1, _onx_transpose0) -> _onx_matmul0
Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.feed_forward.linear_1.bias'
Constant(value=[0.0262538...) -> decoder.feed_forward.linear_2.bias
Linear_3[aten_local_function](layer_norm_1, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> linear_1
ReLU[aten_local_function](linear_1) -> relu
Linear_3[aten_local_function](relu, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_2.bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
input: 'mask2'
input: 'mask'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
input: 'decoder.feed_forward.linear_1.bias'
Constant(value=[1.0, 1.0,...) -> decoder.norm_1.weight
Constant(value=[0.0, 0.0,...) -> decoder.norm_1.bias
LayerNorm[aten_local_function](add, decoder.norm_1.weight, decoder.norm_1.bias) -> norm_1
__main__.MultiAttentionBlock[aten_local_function](norm_1, mask2, mask, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention
Add(attention, add) -> add_1
Constant(value=[1.0, 1.0,...) -> decoder.norm_2.weight
Constant(value=[0.0, 0.0,...) -> decoder.norm_2.bias
LayerNorm[aten_local_function](add_1, decoder.norm_2.weight, decoder.norm_2.bias) -> norm_2
__main__.FeedForward[aten_local_function](norm_2, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> feed_forward
Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?
We check again that there is no new discrepancy.
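The same sketch as before, this time against the model with submodules (reusing feeds defined above):

sess2 = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
got2 = sess2.run(None, feeds)[0]
print(f"output: shape={got2.shape}, min={got2.min()}, max={got2.max()}")
print(f"max discrepancy={max_diff(y, got2)['abs']}")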
output: shape=(1, 30, 16), min=-3.9085307121276855, max=3.940504312515259
max discrepancy=4.76837158203125e-07
Let’s save the ONNX model.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")
And visually.
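The figure comes from plot_dot imported above; a minimal sketch, assuming it accepts a ModelProto and returns a matplotlib Axes:

plot_dot(onx_module)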
<Axes: >
Inlining¶
The ONNX graph can still be inlined after this.
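A minimal sketch using inline_local_functions imported above; the printed graph below is the result:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))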
opset: domain='' version=18
opset: domain='aten_local_function' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=dtype('float32') shape=(1024, 16)
init: name='embedding.pe.weight' type=dtype('float32') shape=(1024, 16)
init: name='mask' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.0.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.value.weight' type=dtype('float32') shape=(16, 16)
init: name='mask2' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.1.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.value.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.linear.weight' type=dtype('float32') shape=(16, 32)
init: name='decoder.feed_forward.linear_1.weight' type=dtype('float32') shape=(128, 16)
init: name='decoder.feed_forward.linear_1.bias' type=dtype('float32') shape=(128,)
init: name='decoder.feed_forward.linear_2.weight' type=dtype('float32') shape=(16, 128)
Constant(value=[1]) -> init7_s1_1__11
Gather(embedding.embedding.weight, input_ids) -> embedding__1
Gather(embedding.pe.weight, input_ids) -> pe__1
Add(embedding__1, pe__1) -> embedding
Constant(value=[1.0, 1.0,...) -> decoder.norm_1.weight__4
Constant(value=[0.0, 0.0,...) -> decoder.norm_1.bias__4
Constant(value=[1.0, 1.0,...) -> decoder.norm_2.weight__4
Constant(value=[0.0, 0.0,...) -> decoder.norm_2.bias__4
Constant(value=[1.0, 1.0,...) -> init1_s16___5
Mul(init1_s16___5, decoder.norm_1.weight__4) -> LayerNormalizationScalePattern_init1_s16___5
Constant(value=[0.0, 0.0,...) -> init1_s16_2__5
Mul(decoder.norm_1.weight__4, init1_s16_2__5) -> LayerNormalizationScalePattern_init1_s16_2__5
Add(LayerNormalizationScalePattern_init1_s16_2__5, decoder.norm_1.bias__4) -> LayerNormalizationScalePattern_init1_s16_3__5
LayerNormalization(embedding, LayerNormalizationScalePattern_init1_s16___5, LayerNormalizationScalePattern_init1_s16_3__5, axis=-1, epsilon=0.00, stash_type=1) -> norm_1__4
Constant(value=[-0.150363...) -> decoder.attention.linear.bias__6
Constant(value=0.25) -> init1_s___7
Constant(value=[1]) -> init7_s1_1__7
Reshape(init1_s___7, init7_s1_1__7) -> _onx_reshape0__7
Constant(value=[0]) -> init7_s1_0__7
Constant(value=[30]) -> init7_s1_30__7
Slice(mask, init7_s1_0__7, init7_s1_30__7, init7_s1_0__7) -> slice_1__7
Slice(slice_1__7, init7_s1_0__7, init7_s1_30__7, init7_s1_1__7) -> slice_2__7
Constant(value=0.0) -> init1_s_2__7
Reshape(init1_s_2__7, init7_s1_1__7) -> _onx_reshape02__7
Equal(slice_2__7, _onx_reshape02__7) -> eq__7_2
Constant(value=[-inf]) -> init1_s1___7
Transpose(decoder.attention.attention.0.query.weight, perm=[1,0]) -> _onx_transpose0__8
MatMul(norm_1__4, _onx_transpose0__8) -> query__7
Transpose(decoder.attention.attention.0.key.weight, perm=[1,0]) -> _onx_transpose0__9
MatMul(norm_1__4, _onx_transpose0__9) -> key__7
Transpose(key__7, perm=[0,2,1]) -> transpose__7_1
MatMul(query__7, transpose__7_1) -> matmul__7
Mul(matmul__7, _onx_reshape0__7) -> _onx_mul0__7
Where(eq__7_2, init1_s1___7, _onx_mul0__7) -> _onx_where0__7
Softmax(_onx_where0__7, axis=-1) -> softmax__7
Transpose(decoder.attention.attention.0.value.weight, perm=[1,0]) -> _onx_transpose0__10
MatMul(norm_1__4, _onx_transpose0__10) -> value__7
MatMul(softmax__7, value__7) -> attention_0__6
Constant(value=0.25) -> init1_s___11
Reshape(init1_s___11, init7_s1_1__11) -> _onx_reshape0__11
Constant(value=[0]) -> init7_s1_0__11
Constant(value=[30]) -> init7_s1_30__11
Slice(mask2, init7_s1_0__11, init7_s1_30__11, init7_s1_0__11) -> slice_1__11
Slice(slice_1__11, init7_s1_0__11, init7_s1_30__11, init7_s1_1__11) -> slice_2__11
Constant(value=0.0) -> init1_s_2__11
Reshape(init1_s_2__11, init7_s1_1__11) -> _onx_reshape02__11
Equal(slice_2__11, _onx_reshape02__11) -> eq__11_4
Constant(value=[-inf]) -> init1_s1___11
Transpose(decoder.attention.attention.1.query.weight, perm=[1,0]) -> _onx_transpose0__12
MatMul(norm_1__4, _onx_transpose0__12) -> query__11
Transpose(decoder.attention.attention.1.key.weight, perm=[1,0]) -> _onx_transpose0__13
MatMul(norm_1__4, _onx_transpose0__13) -> key__11
Transpose(key__11, perm=[0,2,1]) -> transpose__11_3
MatMul(query__11, transpose__11_3) -> matmul__11
Mul(matmul__11, _onx_reshape0__11) -> _onx_mul0__11
Where(eq__11_4, init1_s1___11, _onx_mul0__11) -> _onx_where0__11
Softmax(_onx_where0__11, axis=-1) -> softmax__11
Transpose(decoder.attention.attention.1.value.weight, perm=[1,0]) -> _onx_transpose0__14
MatMul(norm_1__4, _onx_transpose0__14) -> value__11
MatMul(softmax__11, value__11) -> attention_1__6
Concat(attention_0__6, attention_1__6, axis=-1) -> cat__6_0
Transpose(decoder.attention.linear.weight, perm=[1,0]) -> _onx_transpose0__15
MatMul(cat__6_0, _onx_transpose0__15) -> _onx_matmul0__15
Add(_onx_matmul0__15, decoder.attention.linear.bias__6) -> attention__4
Add(attention__4, embedding) -> add_1__4
Constant(value=[1.0, 1.0,...) -> init1_s16___16
Mul(init1_s16___16, decoder.norm_2.weight__4) -> LayerNormalizationScalePattern_init1_s16___16
Constant(value=[0.0, 0.0,...) -> init1_s16_2__16
Mul(decoder.norm_2.weight__4, init1_s16_2__16) -> LayerNormalizationScalePattern_init1_s16_2__16
Add(LayerNormalizationScalePattern_init1_s16_2__16, decoder.norm_2.bias__4) -> LayerNormalizationScalePattern_init1_s16_3__16
LayerNormalization(add_1__4, LayerNormalizationScalePattern_init1_s16___16, LayerNormalizationScalePattern_init1_s16_3__16, axis=-1, epsilon=0.00, stash_type=1) -> norm_2__4
Constant(value=[0.0262538...) -> decoder.feed_forward.linear_2.bias__17
Transpose(decoder.feed_forward.linear_1.weight, perm=[1,0]) -> _onx_transpose0__18
MatMul(norm_2__4, _onx_transpose0__18) -> _onx_matmul0__18
Add(_onx_matmul0__18, decoder.feed_forward.linear_1.bias) -> linear_1__17
Relu(linear_1__17) -> relu__17
Transpose(decoder.feed_forward.linear_2.weight, perm=[1,0]) -> _onx_transpose0__20
MatMul(relu__17, _onx_transpose0__20) -> _onx_matmul0__20
Add(_onx_matmul0__20, decoder.feed_forward.linear_2.bias__17) -> feed_forward__4
Add(feed_forward__4, add_1__4) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Optimizations¶
The ONNX graph produced by the exporter without any optimization is verbose and less efficient than it could be. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.
onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2
    ),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder.optimize] start with 75 nodes
[GraphBuilder.optimize] #patterns=51
[GraphBuilder.remove_unused] 4/46remove_initializer:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 5/46remove_initializer:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 6/46remove_initializer:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 7/46remove_initializer:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 8/46remove_initializer:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 9/46remove_initializer:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 10/46remove_initializer:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder.remove_unused] 14/46remove_initializer:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder.remove_unused] 16/46remove_initializer:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder.remove_unused] 18/46remove_initializer:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 19/46remove_initializer:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 23/46remove_initializer:init1_s_:float32[()]
[GraphBuilder.remove_unused] 24/46remove_initializer:init7_s1_1:int64[(1,)]
[GraphBuilder.remove_unused] 25/46remove_initializer:init7_s1_0:int64[(1,)]
[GraphBuilder.remove_unused] 26/46remove_initializer:init7_s1_30:int64[(1,)]
[GraphBuilder.remove_unused] 27/46remove_initializer:init1_s_2:float32[()]
[GraphBuilder.remove_unused] 33/46remove_initializer:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder.remove_unused] 40/46remove_initializer:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization.optimize] start with 53 nodes, 28 initializers, 51 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization.optimize] use pattern 1/51 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 2/51 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 3/51 - P0 - CastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 4/51 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 5/51 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 6/51 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 7/51 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 8/51 - P0 - GeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 9/51 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 10/51 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 11/51 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 12/51 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 13/51 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 14/51 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 15/51 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 16/51 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 17/51 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 18/51 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 19/51 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 20/51 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 21/51 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 22/51 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 23/51 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 24/51 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 25/51 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 26/51 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 27/51 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 28/51 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 29/51 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 30/51 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 31/51 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern 32/51 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 33/51 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 34/51 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 35/51 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 36/51 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern 37/51 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 38/51 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 39/51 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 40/51 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 41/51 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 42/51 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 43/51 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 44/51 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 45/51 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 46/51 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 47/51 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 48/51 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 49/51 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 50/51 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 51/51 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*CastPattern - time=0.014 | max_time=SoftmaxCrossEntropyLossCastPattern:0.004
[GraphBuilderPatternOptimization.optimize] iteration 1: 51 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] increase priority to 1
[GraphBuilderPatternOptimization.optimize] iteration 2: 51 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.009 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 3: 39 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationScalePattern - time=0.008 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 4: 41 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] increase priority to 2
[GraphBuilderPatternOptimization.optimize] iteration 5: 41 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.006 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 6: 37 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] increase priority to 3
[GraphBuilderPatternOptimization.optimize] iteration 7: 37 nodes, priority=3
[GraphBuilderPatternOptimization.optimize] done after 8 iterations with 37 nodes in 0.076
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002351299990550615
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0014661570021416992
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.001612208998267306
STAT apply_LayerNormalizationScalePattern +8 -6 #it=1 maxmatch=1 i=2 - time=0.001656857999478234
STAT build_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0034324219996051397
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00028798499988624826
STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.001610279992746655
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.000420615997427376
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005751800017606001
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005249440000625327
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004153130066697486
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0012884470052085817
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005035309950471856
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0010113780044775922
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0005228279951552395
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0007516890036640689
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00047723000170663
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002968260014313273
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000539689001016086
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004626949958037585
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00037213200630503707
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044710899601341225
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0003920269991795067
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0008664269953442272
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.403199950000271e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.986199878156185e-05
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.004822103001060896
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0055566909941262566
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=2.0358995243441314e-05
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031650299933971837
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.008763080000790069
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.000779076995968353
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0005384139985835645
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.005380797992984299
STAT match_MatMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007936110050650313
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0013552689924836159
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008081600026343949
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004969670007994864
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004736139999295119
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033958600397454575
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.001289054001972545
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000834900994959753
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0005168720017536543
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007207430062408093
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004518020068644546
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005238180019659922
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0017317149977316149
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006951350005692802
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043415199252194725
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.007078238002577564
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004059839993715286
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006120830003055744
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.001375295003526844
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008386039953620639
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000900467002793448
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0006130619985924568
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0006817969988333061
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007598710035381373
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0005397129934863187
STAT remove_identity_nodes +2 -4 #it=3 maxmatch=0 i=0 - time=0.0008773940026003402
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--
INPUT: 1 x 7t
OUTPUT: 1 x 1t
INIT: 29 x 1t
INIT: 1 x 7t
NODE: 8 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 2 x LayerNormalization
NODE: 11 x MatMul
NODE: 4 x Mul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
NODE: 2 x Add -SIG- 1t[16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 3 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 2 x LayerNormalization -SIG- 1t[1x30x16], 1t[16], 1t[16]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 4 x Mul -SIG- 1t[16], 1t[16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
[GraphBuilder.remove_unused] 9/30remove_initializer:init7_s1_-1:int64[(1,)]
[GraphBuilder.remove_unused] 10/30remove_initializer:init1_s1_:float32[(1,)]
[GraphBuilder.remove_unused] 11/30remove_initializer:init1_s1_2:float32[(1,)]
[GraphBuilder.remove_unused] 16/30remove_initializer:_onx_reshape0:float32[(1,)]
[GraphBuilder.remove_unused] 22/30remove_initializer:_onx_reshape03:float32[(1,)]
[GraphBuilder.remove_unused] 2/31remove_initializer:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 3/31remove_initializer:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 5/31remove_initializer:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 6/31remove_initializer:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 23/31remove_initializer:init1_s16_:float32[(16,)]
[GraphBuilder.remove_unused] 24/31remove_initializer:init1_s16_2:float32[(16,)]
[GraphBuilder.remove_unused] 26/31remove_initializer:LayerNormalizationScalePattern_init1_s16_2:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 29/31remove_initializer:LayerNormalizationScalePattern_init1_s16_5:torch.float32[torch.Size([16])]
[GraphBuilder.optimize] done with 31 nodes in 0.094
opset: domain='' version=18
opset: domain='com.microsoft' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='p_embedding_embedding_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_embedding_pe_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_decoder_attention_linear_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_feed_forward_linear_1_bias' type=dtype('float32') shape=(128,)
init: name='p_decoder_feed_forward_linear_2_bias' type=dtype('float32') shape=(16,)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-inf], dtype=float32)
init: name='_onx_transpose0' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose02' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose03' type=dtype('float32') shape=(16, 16)
init: name='slice_2' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape02' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose04' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose05' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose06' type=dtype('float32') shape=(16, 16)
init: name='slice_4' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape04' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose07' type=dtype('float32') shape=(32, 16)
init: name='_onx_transpose08' type=dtype('float32') shape=(16, 128)
init: name='_onx_transpose09' type=dtype('float32') shape=(128, 16)
init: name='LayerNormalizationScalePattern_init1_s16_' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_3' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_4' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_6' type=dtype('float32') shape=(16,)
Equal(slice_2, _onx_reshape02) -> eq
Gather(p_embedding_embedding_weight, input_ids) -> embedding
Gather(p_embedding_pe_weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add02
MatMul(_onx_add02, _onx_transpose0) -> linear
MatMul(_onx_add02, _onx_transpose02) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul02
Where(eq, init1_s1_3, _onx_mul02) -> _onx_where0
Softmax(_onx_where0, axis=-1) -> softmax
MatMul(_onx_add02, _onx_transpose03) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_add02, _onx_transpose04) -> linear_3
MatMul(_onx_add02, _onx_transpose05) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul03
MatMul(_onx_add02, _onx_transpose06) -> linear_5
Equal(slice_4, _onx_reshape04) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul03) -> _onx_where02
Softmax(_onx_where02, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, _onx_transpose07) -> _onx_matmul0
Add(_onx_matmul0, p_decoder_attention_linear_bias) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, LayerNormalizationScalePattern_init1_s16_4, LayerNormalizationScalePattern_init1_s16_6, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add04
MatMul(_onx_add04, _onx_transpose08) -> _onx_matmul02
Add(_onx_matmul02, p_decoder_feed_forward_linear_1_bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, _onx_transpose09) -> _onx_matmul03
Add(_onx_matmul03, p_decoder_feed_forward_linear_2_bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
This shows the kernel FusedMatMul[com.microsoft], which implements the equivalent of Gemm but works for tensors of any rank, not only 2D.
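As a rough illustration, the following sketch mirrors the operator's semantics (not onnxruntime's actual implementation): scale the product by alpha and optionally transpose the last two dimensions of each input, broadcasting over batch dimensions like a regular MatMul.
def fused_matmul_reference(a, b, alpha=1.0, transA=False, transB=False):
    # Reference semantics over the last two dimensions; batch
    # dimensions broadcast as in a regular MatMul.
    if transA:
        a = a.transpose(-2, -1)
    if transB:
        b = b.transpose(-2, -1)
    return alpha * (a @ b)
query = torch.randn(1, 30, 16)
key = torch.randn(1, 30, 16)
# Matches the pattern above: alpha=0.25 is C**-0.5 with C=16, transB=1.
qk = fused_matmul_reference(query, key, alpha=0.25, transB=True)
assert qk.shape == (1, 30, 30)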
How does it behave on the model which keeps the modules exported as local functions? The optimizer optimizes every local function independently.
We reduce the verbosity…
onx_module_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=dtype('float32') shape=(1024, 16)
init: name='embedding.pe.weight' type=dtype('float32') shape=(1024, 16)
init: name='_onx_transpose0' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose02' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose03' type=dtype('float32') shape=(16, 16)
init: name='slice_2' type=dtype('float32') shape=(30, 30)
init: name='_onx_transpose04' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose022' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose032' type=dtype('float32') shape=(16, 16)
init: name='slice_4' type=dtype('float32') shape=(30, 30)
init: name='_onx_transpose05' type=dtype('float32') shape=(32, 16)
init: name='decoder.feed_forward.linear_1.bias' type=dtype('float32') shape=(128,)
init: name='_onx_transpose06' type=dtype('float32') shape=(16, 128)
init: name='_onx_transpose023' type=dtype('float32') shape=(128, 16)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
__main__.DecoderLayer[aten_local_function](embedding, _onx_transpose06, _onx_transpose023, slice_4, slice_2, _onx_transpose05, _onx_transpose04, _onx_transpose032, _onx_transpose03, _onx_transpose022, _onx_transpose02, _onx_transpose0, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
Constant(value=[1.0, 1.0,...) -> LayerNormalizationScalePattern_init1_s16_
Constant(value=[0.0, 0.0,...) -> LayerNormalizationScalePattern_init1_s16_3
LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: '_onx_transpose0'
MatMul(layer_norm, _onx_transpose0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_2'
input: '_onx_transpose03'
input: '_onx_transpose02'
input: '_onx_transpose0'
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.0]) -> _onx_reshape02
Equal(slice_2, _onx_reshape02) -> eq
Linear[aten_local_function](layer_norm, _onx_transpose0) -> query
Linear[aten_local_function](layer_norm, _onx_transpose02) -> key
FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul0
Where(eq, init1_s1_, _onx_mul0) -> _onx_where0
Softmax(_onx_where0, axis=-1) -> softmax
Linear[aten_local_function](layer_norm, _onx_transpose03) -> value
MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'cat'
input: '_onx_transpose0'
input: 'bias'
MatMul(cat, _onx_transpose0) -> _onx_matmul0
Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_4'
input: 'slice_2'
input: '_onx_transpose05'
input: '_onx_transpose04'
input: '_onx_transpose032'
input: '_onx_transpose03'
input: '_onx_transpose022'
input: '_onx_transpose02'
input: '_onx_transpose0'
Constant(value=[-0.150363...) -> decoder.attention.linear.bias
__main__.AttentionBlock[aten_local_function](layer_norm, slice_2, _onx_transpose03, _onx_transpose02, _onx_transpose0) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, slice_4, _onx_transpose032, _onx_transpose022, _onx_transpose04) -> attention_1
Concat(attention_0, attention_1, axis=-1) -> cat
Linear_2[aten_local_function](cat, _onx_transpose05, decoder.attention.linear.bias) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: '_onx_transpose0'
input: 'bias'
MatMul(layer_norm_1, _onx_transpose0) -> _onx_matmul0
Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: '_onx_transpose02'
input: '_onx_transpose0'
input: 'decoder.feed_forward.linear_1.bias'
Constant(value=[0.0262538...) -> decoder.feed_forward.linear_2.bias
Linear_3[aten_local_function](layer_norm_1, _onx_transpose0, decoder.feed_forward.linear_1.bias) -> linear_1
ReLU[aten_local_function](linear_1) -> relu
Linear_3[aten_local_function](relu, _onx_transpose02, decoder.feed_forward.linear_2.bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'add'
input: '_onx_transpose06'
input: '_onx_transpose023'
input: 'slice_4'
input: 'slice_2'
input: '_onx_transpose05'
input: '_onx_transpose04'
input: '_onx_transpose032'
input: '_onx_transpose03'
input: '_onx_transpose022'
input: '_onx_transpose02'
input: '_onx_transpose0'
input: 'decoder.feed_forward.linear_1.bias'
LayerNorm[aten_local_function](add) -> norm_1
__main__.MultiAttentionBlock[aten_local_function](norm_1, slice_4, slice_2, _onx_transpose05, _onx_transpose04, _onx_transpose032, _onx_transpose03, _onx_transpose022, _onx_transpose02, _onx_transpose0) -> attention
Add(attention, add) -> add_1
LayerNorm[aten_local_function](add_1) -> norm_2
__main__.FeedForward[aten_local_function](norm_2, _onx_transpose023, _onx_transpose06, decoder.feed_forward.linear_1.bias) -> feed_forward
Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?
It seems to work as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
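To gain confidence, a quick numerical check can compare the eager PyTorch output with the optimized ONNX model. The following is only a sketch: it assumes onnxruntime can run the model-local functions produced here and that max_diff (imported at the top) accepts a torch tensor and a numpy array.
feeds = {"input_ids": input_ids.detach().numpy()}
sess = InferenceSession(
    onx_module_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, feeds)[0]
with torch.no_grad():
    expected = llm(input_ids)
# max_diff is assumed to report the maximum discrepancy between the two.
print(max_diff(expected, got))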
Optimizations for CUDA¶
The optimizer may behave differently when it knows the model will run on CUDA. It may use different kernels and apply different optimizations if needed. This may not be the right place to make that decision, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more suitable for a given kernel. This could even be decided at runtime.
onx_cuda_optimized = to_onnx(
llm,
(input_ids,),
options=OptimizationOptions(
patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder.optimize] start with 75 nodes
[GraphBuilder.optimize] #patterns=51
[GraphBuilder.remove_unused] 4/46remove_initializer:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 5/46remove_initializer:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 6/46remove_initializer:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 7/46remove_initializer:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 8/46remove_initializer:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 9/46remove_initializer:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 10/46remove_initializer:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder.remove_unused] 14/46remove_initializer:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder.remove_unused] 16/46remove_initializer:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder.remove_unused] 18/46remove_initializer:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 19/46remove_initializer:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 23/46remove_initializer:init1_s_:float32[()]
[GraphBuilder.remove_unused] 24/46remove_initializer:init7_s1_1:int64[(1,)]
[GraphBuilder.remove_unused] 25/46remove_initializer:init7_s1_0:int64[(1,)]
[GraphBuilder.remove_unused] 26/46remove_initializer:init7_s1_30:int64[(1,)]
[GraphBuilder.remove_unused] 27/46remove_initializer:init1_s_2:float32[()]
[GraphBuilder.remove_unused] 33/46remove_initializer:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder.remove_unused] 40/46remove_initializer:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization.optimize] start with 53 nodes, 28 initializers, 51 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization.optimize] use pattern 1/51 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 2/51 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 3/51 - P0 - CastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 4/51 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 5/51 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 6/51 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 7/51 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 8/51 - P0 - GeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 9/51 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 10/51 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 11/51 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 12/51 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 13/51 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 14/51 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 15/51 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 16/51 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 17/51 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 18/51 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 19/51 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 20/51 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 21/51 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 22/51 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 23/51 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 24/51 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 25/51 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 26/51 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 27/51 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 28/51 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 29/51 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 30/51 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 31/51 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern 32/51 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 33/51 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 34/51 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 35/51 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 36/51 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern 37/51 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 38/51 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 39/51 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 40/51 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 41/51 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 42/51 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 43/51 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 44/51 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 45/51 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 46/51 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 47/51 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 48/51 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 49/51 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern 50/51 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern 51/51 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*CastPattern - time=0.013 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilderPatternOptimization.optimize] iteration 1: 51 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] increase priority to 1
[GraphBuilderPatternOptimization.optimize] iteration 2: 51 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 3: 39 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationScalePattern - time=0.004 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 4: 41 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] increase priority to 2
[GraphBuilderPatternOptimization.optimize] iteration 5: 41 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.007 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 6: 37 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] increase priority to 3
[GraphBuilderPatternOptimization.optimize] iteration 7: 37 nodes, priority=3
[GraphBuilderPatternOptimization.optimize] done after 8 iterations with 37 nodes in 0.059
STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00033199400058947504
STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.00044636500024353154
STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004514319989539217
STAT apply_LayerNormalizationScalePattern +8 -6 #it=1 maxmatch=1 i=2 - time=0.0007248299989441875
STAT build_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0022589520012843423
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00031031700200401247
STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0009253140015061945
STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.00036864899811916985
STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0004700450044765603
STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.000425509009801317
STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002755290006462019
STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000751970994315343
STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00033706999602145515
STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0007624669997312594
STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00047389600513270125
STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000455398992926348
STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00040974299918161705
STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021438399926410057
STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023231499289977364
STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00040036199789028615
STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00022855099814478308
STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030827800583210774
STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0003567460007616319
STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0007394730018859264
STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.263699858915061e-05
STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.395899981725961e-05
STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0036905640008626506
STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.004693301001680084
STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=1.3928998669143766e-05
STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035539600139600225
STAT match_IdentityPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.007192787001258694
STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00045746899559162557
STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.000513541996042477
STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.004610444000718417
STAT match_MatMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006537789995491039
STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0011999769958492834
STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000663203994918149
STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005202160064072814
STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004569890006678179
STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030697799593326636
STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0012463210005080327
STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006260909940465353
STAT match_ReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0006521489995066077
STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006813899999542627
STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004642429994419217
STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005402189999585971
STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0013755860018136445
STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000347028995747678
STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003770200055441819
STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.006253109993849648
STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026055299895233475
STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003796259989030659
STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0010040179986390285
STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009082119977392722
STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006176680035423487
STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0005275410003378056
STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004835229992750101
STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004754610054078512
STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00039690500125288963
STAT remove_identity_nodes +2 -4 #it=3 maxmatch=0 i=0 - time=0.0007931020008982159
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--
INPUT: 1 x 7t
OUTPUT: 1 x 1t
INIT: 29 x 1t
INIT: 1 x 7t
NODE: 8 x Add
NODE: 1 x Concat
NODE: 2 x Equal
NODE: 2 x Gather
NODE: 2 x LayerNormalization
NODE: 11 x MatMul
NODE: 4 x Mul
NODE: 1 x Relu
NODE: 2 x Softmax
NODE: 2 x Where
NODE: 2 x com.microsoft.FusedMatMul
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
INPUT: 1 x 7t[1x30]
OUTPUT: 1 x 1t[1x30x16]
INIT: 2 x 1t[1024x16]
INIT: 1 x 1t[128]
INIT: 1 x 1t[128x16]
INIT: 8 x 1t[16]
INIT: 1 x 1t[16x128]
INIT: 6 x 1t[16x16]
INIT: 7 x 1t[1]
INIT: 2 x 1t[30x30]
INIT: 1 x 1t[32x16]
INIT: 1 x 7t[1]
NODE: 2 x Add -SIG- 1t[16], 1t[16]
NODE: 1 x Add -SIG- 1t[1x30x128], 1t[128]
NODE: 2 x Add -SIG- 1t[1x30x16], 1t[16]
NODE: 3 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
NODE: 2 x Equal -SIG- 1t[30x30], 1t[1]
NODE: 2 x Gather -SIG- 1t[1024x16], 7t[1x30]
NODE: 2 x LayerNormalization -SIG- 1t[1x30x16], 1t[16], 1t[16]
NODE: 1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
NODE: 1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
NODE: 6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
NODE: 2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
NODE: 1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
NODE: 4 x Mul -SIG- 1t[16], 1t[16]
NODE: 1 x Relu -SIG- 1t[1x30x128]
NODE: 2 x Softmax -SIG- 1t[1x30x30]
NODE: 2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
NODE: 2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
[GraphBuilder.remove_unused] 9/30remove_initializer:init7_s1_-1:int64[(1,)]
[GraphBuilder.remove_unused] 10/30remove_initializer:init1_s1_:float32[(1,)]
[GraphBuilder.remove_unused] 11/30remove_initializer:init1_s1_2:float32[(1,)]
[GraphBuilder.remove_unused] 16/30remove_initializer:_onx_reshape0:float32[(1,)]
[GraphBuilder.remove_unused] 22/30remove_initializer:_onx_reshape03:float32[(1,)]
[GraphBuilder.remove_unused] 2/31remove_initializer:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 3/31remove_initializer:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 5/31remove_initializer:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 6/31remove_initializer:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 23/31remove_initializer:init1_s16_:float32[(16,)]
[GraphBuilder.remove_unused] 24/31remove_initializer:init1_s16_2:float32[(16,)]
[GraphBuilder.remove_unused] 26/31remove_initializer:LayerNormalizationScalePattern_init1_s16_2:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 29/31remove_initializer:LayerNormalizationScalePattern_init1_s16_5:torch.float32[torch.Size([16])]
[GraphBuilder.optimize] done with 31 nodes in 0.074
opset: domain='' version=18
opset: domain='com.microsoft' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='p_embedding_embedding_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_embedding_pe_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_decoder_attention_linear_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_feed_forward_linear_1_bias' type=dtype('float32') shape=(128,)
init: name='p_decoder_feed_forward_linear_2_bias' type=dtype('float32') shape=(16,)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-inf], dtype=float32)
init: name='_onx_transpose0' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose02' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose03' type=dtype('float32') shape=(16, 16)
init: name='slice_2' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape02' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose04' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose05' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose06' type=dtype('float32') shape=(16, 16)
init: name='slice_4' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape04' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose07' type=dtype('float32') shape=(32, 16)
init: name='_onx_transpose08' type=dtype('float32') shape=(16, 128)
init: name='_onx_transpose09' type=dtype('float32') shape=(128, 16)
init: name='LayerNormalizationScalePattern_init1_s16_' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_3' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_4' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_6' type=dtype('float32') shape=(16,)
Equal(slice_2, _onx_reshape02) -> eq
Gather(p_embedding_embedding_weight, input_ids) -> embedding
Gather(p_embedding_pe_weight, input_ids) -> embedding_1
Add(embedding, embedding_1) -> add
LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add02
MatMul(_onx_add02, _onx_transpose0) -> linear
MatMul(_onx_add02, _onx_transpose02) -> linear_1
FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul02
Where(eq, init1_s1_3, _onx_mul02) -> _onx_where0
Softmax(_onx_where0, axis=-1) -> softmax
MatMul(_onx_add02, _onx_transpose03) -> linear_2
MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_add02, _onx_transpose04) -> linear_3
MatMul(_onx_add02, _onx_transpose05) -> linear_4
FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul03
MatMul(_onx_add02, _onx_transpose06) -> linear_5
Equal(slice_4, _onx_reshape04) -> eq_1
Where(eq_1, init1_s1_3, _onx_mul03) -> _onx_where02
Softmax(_onx_where02, axis=-1) -> softmax_1
MatMul(softmax_1, linear_5) -> matmul_3
Concat(matmul_1, matmul_3, axis=-1) -> cat
MatMul(cat, _onx_transpose07) -> _onx_matmul0
Add(_onx_matmul0, p_decoder_attention_linear_bias) -> linear_6
Add(linear_6, add) -> add_1
LayerNormalization(add_1, LayerNormalizationScalePattern_init1_s16_4, LayerNormalizationScalePattern_init1_s16_6, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add04
MatMul(_onx_add04, _onx_transpose08) -> _onx_matmul02
Add(_onx_matmul02, p_decoder_feed_forward_linear_1_bias) -> linear_7
Relu(linear_7) -> relu
MatMul(relu, _onx_transpose09) -> _onx_matmul03
Add(_onx_matmul03, p_decoder_feed_forward_linear_2_bias) -> linear_8
Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
Comparing the optimized and non-optimized models¶
The following tool tries to match the nodes and shape inference results of two models. If the models are not too different, the function is able to pinpoint the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.
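The output below was produced by a call along these lines, a sketch using compare_onnx_execution imported at the top; the variable names for the raw and optimized exports are assumptions, not taken from the source.
# Hypothetical variable names: onx (raw export) and onx_optimized.
res1, res2, align, dc = compare_onnx_execution(onx, onx_optimized, verbose=1)
print("------------")
# dc is assumed to expose to_str to render the aligned comparison.
print(dc.to_str(res1, res2, align))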
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 88 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 88 results (first model)
[compare_onnx_execution] got 56 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 90 pairs
[compare_onnx_execution] done
------------
001 = | INITIA float32 2:1024x16 BORW p_ | INITIA float32 2:1024x16 BORW p_
002 = | INITIA float32 2:1024x16 UUVW p_ | INITIA float32 2:1024x16 UUVW p_
003 ~ | INITIA float32 1:16 EEEE p_ | INITIA float32 1:16 AAAA p_
004 ~ | INITIA float32 1:16 AAAA p_ | INITIA float32 1:128 AAAA p_
005 ~ | INITIA float32 2:16x16 AAAA p_ | INITIA float32 1:16 AAAA p_
006 ~ | INITIA float32 2:16x16 AAAA p_ | INITIA float32 1:1 ?AAA in
007 - | INITIA float32 2:16x16 YAAC p_ |
008 ~ | INITIA float32 2:16x16 AAZB p_ | INITIA float32 2:16x16 CAZA _o
009 ~ | INITIA float32 2:16x16 AABA p_ | INITIA float32 2:16x16 AAAA _o
010 ~ | INITIA float32 2:16x16 ZZAA p_ | INITIA float32 2:16x16 AAAA _o
011 ~ | INITIA float32 2:16x32 AAYA p_ | INITIA float32 2:30x30 KGSP sl
012 ~ | INITIA float32 1:16 AAAA p_ | INITIA float32 1:1 AAAA _o
013 ~ | INITIA float32 1:16 EEEE p_ | INITIA float32 2:16x16 ZAAA _o
014 ~ | INITIA float32 1:16 AAAA p_ | INITIA float32 2:16x16 AAAA _o
015 - | INITIA float32 2:128x16 BAAX p_ |
016 ~ | INITIA float32 1:128 AAAA p_ | INITIA float32 2:16x16 AAZY _o
017 - | INITIA float32 2:16x128 AZAA p_ |
018 ~ | INITIA float32 1:16 AAAA p_ | INITIA float32 2:30x30 KGSP sl
019 - | INITIA float32 2:256x256 AOCQ b_ |
020 - | INITIA float32 2:256x256 AOCQ b_ |
021 - | INITIA float32 AAAA in |
022 ~ | INITIA int64 1:1 BAAA in | INITIA float32 1:1 AAAA _o
023 ~ | INITIA int64 1:1 AAAA in | INITIA float32 2:32x16 BAAA _o
024 + | | INITIA float32 2:16x128 AVEA _o
025 + | | INITIA float32 2:128x16 BZAA _o
026 ~ | INITIA int64 1:1 EAAA in | INITIA float32 1:16 EEEE La
027 - | INITIA float32 AAAA in |
028 ~ | INITIA float32 1:1 ?AAA in | INITIA float32 1:16 AAAA La
029 = | INITIA float32 1:16 EEEE in | INITIA float32 1:16 EEEE La
030 = | INITIA float32 1:16 AAAA in | INITIA float32 1:16 AAAA La
031 = | INPUT int64 2:1x30 COAD in | INPUT int64 2:1x30 COAD in
032 = | RESULT float32 3:1x30x16 FDNV Gather em | RESULT float32 3:1x30x16 FDNV Gather em
033 = | RESULT float32 3:1x30x16 QUDH Gather em | RESULT float32 3:1x30x16 QUDH Gather em
034 = | RESULT float32 3:1x30x16 WYQC Add ad | RESULT float32 3:1x30x16 WYQC Add ad
035 - | RESULT float32 1:16 EEEE Mul La |
036 - | RESULT float32 1:16 AAAA Mul La |
037 - | RESULT float32 1:16 AAAA Add La |
038 = | RESULT float32 3:1x30x16 ZBBZ LayerNormalizat _o | RESULT float32 3:1x30x16 ZBBZ LayerNormalizat _o
039 - | RESULT float32 2:16x16 CAZA Transpose _o |
040 = | RESULT float32 3:1x30x16 MQST MatMul li | RESULT float32 3:1x30x16 MQST MatMul li
041 - | RESULT float32 2:16x16 AAAA Transpose _o |
042 = | RESULT float32 3:1x30x16 VXOC MatMul li | RESULT float32 3:1x30x16 VXOC MatMul li
043 - | RESULT float32 2:16x16 AAAA Transpose _o |
044 ~ | RESULT float32 3:1x30x16 VBUE MatMul li | RESULT float32 3:1x30x30 YCFE FusedMatMul _o
045 ~ | RESULT float32 3:1x16x30 WAUT Transpose tr | RESULT float32 3:1x30x16 VBUE MatMul li
046 - | RESULT float32 3:1x30x30 QLXR MatMul ma |
047 - | RESULT float32 1:1 AAAA Reshape _o |
048 - | RESULT float32 3:1x30x30 YCFE Mul _o |
049 - | RESULT float32 2:30x256 KGAH Slice sl |
050 - | RESULT float32 2:30x30 KGSP Slice sl |
051 - | RESULT float32 1:1 AAAA Reshape _o |
052 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
053 = | RESULT float32 3:1x30x30 ???? Where _o | RESULT float32 3:1x30x30 ???? Where _o
054 = | RESULT float32 3:1x30x30 HGHH Softmax so | RESULT float32 3:1x30x30 HGHH Softmax so
055 = | RESULT float32 3:1x30x16 DYYY MatMul ma | RESULT float32 3:1x30x16 DYYY MatMul ma
056 - | RESULT float32 2:16x16 ZAAA Transpose _o |
057 = | RESULT float32 3:1x30x16 BLRA MatMul li | RESULT float32 3:1x30x16 BLRA MatMul li
058 - | RESULT float32 2:16x16 AAAA Transpose _o |
059 = | RESULT float32 3:1x30x16 ZQZD MatMul li | RESULT float32 3:1x30x16 ZQZD MatMul li
060 - | RESULT float32 2:16x16 AAZY Transpose _o |
061 - | RESULT float32 3:1x30x16 ZBFC MatMul li |
062 - | RESULT float32 3:1x16x30 PBAC Transpose tr |
063 ~ | RESULT float32 3:1x30x30 KHTN MatMul ma | RESULT float32 3:1x30x30 CWZX FusedMatMul _o
064 - | RESULT float32 1:1 AAAA Reshape _o |
065 ~ | RESULT float32 3:1x30x30 CWZX Mul _o | RESULT float32 3:1x30x16 ZBFC MatMul li
066 - | RESULT float32 2:30x256 KGAH Slice sl |
067 - | RESULT float32 2:30x30 KGSP Slice sl |
068 - | RESULT float32 1:1 AAAA Reshape _o |
069 = | RESULT bool 2:30x30 HLZC Equal eq | RESULT bool 2:30x30 HLZC Equal eq
070 = | RESULT float32 3:1x30x30 ???? Where _o | RESULT float32 3:1x30x30 ???? Where _o
071 = | RESULT float32 3:1x30x30 HHHH Softmax so | RESULT float32 3:1x30x30 HHHH Softmax so
072 = | RESULT float32 3:1x30x16 SCAB MatMul ma | RESULT float32 3:1x30x16 SCAB MatMul ma
073 = | RESULT float32 3:1x30x32 VAZA Concat ca | RESULT float32 3:1x30x32 VAZA Concat ca
074 - | RESULT float32 2:32x16 BAAA Transpose _o |
075 = | RESULT float32 3:1x30x16 WAAA MatMul _o | RESULT float32 3:1x30x16 WAAA MatMul _o
076 = | RESULT float32 3:1x30x16 TYYY Add li | RESULT float32 3:1x30x16 TYYY Add li
077 = | RESULT float32 3:1x30x16 OVNA Add ad | RESULT float32 3:1x30x16 OVNA Add ad
078 - | RESULT float32 1:16 EEEE Mul La |
079 - | RESULT float32 1:16 AAAA Mul La |
080 - | RESULT float32 1:16 AAAA Add La |
081 = | RESULT float32 3:1x30x16 ZBBZ LayerNormalizat _o | RESULT float32 3:1x30x16 ZBBZ LayerNormalizat _o
082 - | RESULT float32 2:16x128 AVEA Transpose _o |
083 = | RESULT float32 3:1x30x128 EMOF MatMul _o | RESULT float32 3:1x30x128 EMOF MatMul _o
084 = | RESULT float32 3:1x30x128 PXZQ Add li | RESULT float32 3:1x30x128 PXZQ Add li
085 = | RESULT float32 3:1x30x128 GZAU Relu re | RESULT float32 3:1x30x128 GZAU Relu re
086 - | RESULT float32 2:128x16 BZAA Transpose _o |
087 = | RESULT float32 3:1x30x16 AZBC MatMul _o | RESULT float32 3:1x30x16 AZBC MatMul _o
088 = | RESULT float32 3:1x30x16 AZBC Add li | RESULT float32 3:1x30x16 AZBC Add li
089 = | RESULT float32 3:1x30x16 PUPD Add ou | RESULT float32 3:1x30x16 PUPD Add ou
090 = | OUTPUT float32 3:1x30x16 PUPD ou | OUTPUT float32 3:1x30x16 PUPD ou
The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that's a topic for another example.
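As a closing sketch, a dynamic sequence length could be declared at export time. This is hypothetical: it assumes to_onnx forwards a torch.export-style dynamic_shapes argument, which this example does not verify.
# Hypothetical: mark the second dimension of input_ids as dynamic.
seq = torch.export.Dim("seq", min=2, max=256)
onx_dynamic = to_onnx(
    llm,
    (input_ids,),
    dynamic_shapes={"input_ids": {1: seq}},
)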
Total running time of the script: (0 minutes 6.291 seconds)