to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
import torch
from onnxruntime import InferenceSession
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.bench_run import max_diff
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        B, T, C = x.size()

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)

        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x

        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention

        return ff


class LLM(torch.nn.Module):

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y


llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-3.9085302352905273, max=3.940504550933838

First conversion to ONNX

The conversion relies on torch.export.export(), which gives:

ep = torch.export.export(llm, (input_ids,))
print(ep.graph)

# Then function :func:`to_onnx <experimental_experiment.torch_interpreter.to_onnx>`
# converts it into ONNX.

onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='p_embedding_embedding_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_embedding_pe_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_decoder_norm_1_weight' type=dtype('float32') shape=(16,)
init: name='p_decoder_norm_1_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_attention_attention_0_query_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_0_key_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_0_value_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_1_query_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_1_key_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_attention_1_value_weight' type=dtype('float32') shape=(16, 16)
init: name='p_decoder_attention_linear_weight' type=dtype('float32') shape=(16, 32)
init: name='p_decoder_attention_linear_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_norm_2_weight' type=dtype('float32') shape=(16,)
init: name='p_decoder_norm_2_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_feed_forward_linear_1_weight' type=dtype('float32') shape=(128, 16)
init: name='p_decoder_feed_forward_linear_1_bias' type=dtype('float32') shape=(128,)
init: name='p_decoder_feed_forward_linear_2_weight' type=dtype('float32') shape=(16, 128)
init: name='p_decoder_feed_forward_linear_2_bias' type=dtype('float32') shape=(16,)
init: name='b_decoder_attention_attention_0_mask' type=dtype('float32') shape=(256, 256)
init: name='b_decoder_attention_attention_1_mask' type=dtype('float32') shape=(256, 256)
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.25], dtype=float32)
init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
init: name='init7_s1_0' type=dtype('int64') shape=(1,) -- array([0])
init: name='init7_s1_30' type=dtype('int64') shape=(1,) -- array([30])
init: name='init1_s_2' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-inf], dtype=float32)
init: name='init1_s16_' type=dtype('float32') shape=(16,)
init: name='init1_s16_2' type=dtype('float32') shape=(16,)
Gather(p_embedding_embedding_weight, input_ids) -> embedding
Gather(p_embedding_pe_weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
Mul(init1_s16_, p_decoder_norm_1_weight) -> LayerNormalizationScalePattern_init1_s16_
Mul(p_decoder_norm_1_weight, init1_s16_2) -> LayerNormalizationScalePattern_init1_s16_2
  Add(LayerNormalizationScalePattern_init1_s16_2, p_decoder_norm_1_bias) -> LayerNormalizationScalePattern_init1_s16_3
  LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add02
Transpose(p_decoder_attention_attention_0_query_weight, perm=[1,0]) -> _onx_transpose0
  MatMul(_onx_add02, _onx_transpose0) -> linear
Transpose(p_decoder_attention_attention_0_key_weight, perm=[1,0]) -> _onx_transpose02
  MatMul(_onx_add02, _onx_transpose02) -> linear_1
    Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
Transpose(p_decoder_attention_attention_0_value_weight, perm=[1,0]) -> _onx_transpose03
  MatMul(_onx_add02, _onx_transpose03) -> linear_2
Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
  Mul(matmul, _onx_reshape0) -> _onx_mul02
Slice(b_decoder_attention_attention_0_mask, init7_s1_0, init7_s1_30, init7_s1_0) -> slice_1
  Slice(slice_1, init7_s1_0, init7_s1_30, init7_s1_1) -> slice_2
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape02
  Equal(slice_2, _onx_reshape02) -> eq
    Where(eq, init1_s1_3, _onx_mul02) -> _onx_where0
      Softmax(_onx_where0, axis=-1) -> softmax
    MatMul(softmax, linear_2) -> matmul_1
Transpose(p_decoder_attention_attention_1_query_weight, perm=[1,0]) -> _onx_transpose04
  MatMul(_onx_add02, _onx_transpose04) -> linear_3
Transpose(p_decoder_attention_attention_1_key_weight, perm=[1,0]) -> _onx_transpose05
  MatMul(_onx_add02, _onx_transpose05) -> linear_4
    Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
Transpose(p_decoder_attention_attention_1_value_weight, perm=[1,0]) -> _onx_transpose06
  MatMul(_onx_add02, _onx_transpose06) -> linear_5
Reshape(init1_s_, init7_s1_1) -> _onx_reshape03
  Mul(matmul_2, _onx_reshape03) -> _onx_mul03
Slice(b_decoder_attention_attention_1_mask, init7_s1_0, init7_s1_30, init7_s1_0) -> slice_3
  Slice(slice_3, init7_s1_0, init7_s1_30, init7_s1_1) -> slice_4
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape04
  Equal(slice_4, _onx_reshape04) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul03) -> _onx_where02
      Softmax(_onx_where02, axis=-1) -> softmax_1
    MatMul(softmax_1, linear_5) -> matmul_3
      Concat(matmul_1, matmul_3, axis=-1) -> cat
Transpose(p_decoder_attention_linear_weight, perm=[1,0]) -> _onx_transpose07
  MatMul(cat, _onx_transpose07) -> _onx_matmul0
    Add(_onx_matmul0, p_decoder_attention_linear_bias) -> linear_6
    Add(linear_6, add) -> add_1
Mul(init1_s16_, p_decoder_norm_2_weight) -> LayerNormalizationScalePattern_init1_s16_4
Mul(p_decoder_norm_2_weight, init1_s16_2) -> LayerNormalizationScalePattern_init1_s16_5
  Add(LayerNormalizationScalePattern_init1_s16_5, p_decoder_norm_2_bias) -> LayerNormalizationScalePattern_init1_s16_6
  LayerNormalization(add_1, LayerNormalizationScalePattern_init1_s16_4, LayerNormalizationScalePattern_init1_s16_6, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add04
Transpose(p_decoder_feed_forward_linear_1_weight, perm=[1,0]) -> _onx_transpose08
  MatMul(_onx_add04, _onx_transpose08) -> _onx_matmul02
    Add(_onx_matmul02, p_decoder_feed_forward_linear_1_bias) -> linear_7
      Relu(linear_7) -> relu
Transpose(p_decoder_feed_forward_linear_2_weight, perm=[1,0]) -> _onx_transpose09
  MatMul(relu, _onx_transpose09) -> _onx_matmul03
    Add(_onx_matmul03, p_decoder_feed_forward_linear_2_bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check that there is no discrepancy.

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.9085307121276855, max=3.940504312515259
max discrepancy=4.76837158203125e-07
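
The 'abs' entry is the maximum absolute difference. A minimal sketch of the same check written directly with torch, as an approximation of what max_diff reports:

# approximate what max_diff(y, got)['abs'] returns: the maximum absolute difference
print(float((y - torch.from_numpy(got)).abs().max()))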

Let’s save the ONNX model.

onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. Function to_onnx calls torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the called module (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule, taken from the previous graph.
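
This can be checked directly. A minimal sketch, reusing the exported program ep from the first conversion; the class name InterpreterModule reflects the current torch.export implementation and should be treated as an assumption:

unflat = torch.export.unflatten(ep)
# the unflattened graph calls submodules instead of inlined aten ops
print(unflat.graph)
# the submodule is an InterpreterModule, not the original Embedding
print(type(unflat.get_submodule("embedding")))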

Now the ONNX graph.

onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=dtype('float32') shape=(1024, 16)
init: name='embedding.pe.weight' type=dtype('float32') shape=(1024, 16)
init: name='mask' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.0.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.value.weight' type=dtype('float32') shape=(16, 16)
init: name='mask2' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.1.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.value.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.linear.weight' type=dtype('float32') shape=(16, 32)
init: name='decoder.feed_forward.linear_1.weight' type=dtype('float32') shape=(128, 16)
init: name='decoder.feed_forward.linear_1.bias' type=dtype('float32') shape=(128,)
init: name='decoder.feed_forward.linear_2.weight' type=dtype('float32') shape=(16, 128)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
  __main__.DecoderLayer[aten_local_function](embedding, mask2, mask, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
  Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
input: 'weight'
input: 'bias'
Constant(value=[1.0, 1.0,...) -> init1_s16_
  Mul(init1_s16_, weight) -> LayerNormalizationScalePattern_init1_s16_
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  Mul(weight, init1_s16_2) -> LayerNormalizationScalePattern_init1_s16_2
    Add(LayerNormalizationScalePattern_init1_s16_2, bias) -> LayerNormalizationScalePattern_init1_s16_3
    LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'weight'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
  MatMul(layer_norm, _onx_transpose0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
Constant(value=0.25) -> init1_s_
Constant(value=[1]) -> init7_s1_1
  Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
Constant(value=[0]) -> init7_s1_0
Constant(value=[30]) -> init7_s1_30
  Slice(mask, init7_s1_0, init7_s1_30, init7_s1_0) -> slice_1
  Slice(slice_1, init7_s1_0, init7_s1_30, init7_s1_1) -> slice_2
Constant(value=0.0) -> init1_s_2
  Reshape(init1_s_2, init7_s1_1) -> _onx_reshape02
    Equal(slice_2, _onx_reshape02) -> eq
Constant(value=[-inf]) -> init1_s1_
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.query.weight) -> query
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.key.weight) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
  MatMul(query, transpose) -> matmul
    Mul(matmul, _onx_reshape0) -> _onx_mul0
  Where(eq, init1_s1_, _onx_mul0) -> _onx_where0
    Softmax(_onx_where0, axis=-1) -> softmax
Linear[aten_local_function](layer_norm, decoder.attention.attention.0.value.weight) -> value
  MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'cat'
input: 'weight'
input: 'bias'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
  MatMul(cat, _onx_transpose0) -> _onx_matmul0
    Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: 'mask2'
input: 'mask'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
Constant(value=[-0.150363...) -> decoder.attention.linear.bias
__main__.AttentionBlock[aten_local_function](layer_norm, mask, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, mask2, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight) -> attention_1
  Concat(attention_0, attention_1, axis=-1) -> cat
  Linear_2[aten_local_function](cat, decoder.attention.linear.weight, decoder.attention.linear.bias) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'weight'
input: 'bias'
Transpose(weight, perm=[1,0]) -> _onx_transpose0
  MatMul(layer_norm_1, _onx_transpose0) -> _onx_matmul0
    Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm_1'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.feed_forward.linear_1.bias'
Constant(value=[0.0262538...) -> decoder.feed_forward.linear_2.bias
Linear_3[aten_local_function](layer_norm_1, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> linear_1
  ReLU[aten_local_function](linear_1) -> relu
  Linear_3[aten_local_function](relu, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_2.bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
input: 'mask2'
input: 'mask'
input: 'decoder.feed_forward.linear_2.weight'
input: 'decoder.feed_forward.linear_1.weight'
input: 'decoder.attention.linear.weight'
input: 'decoder.attention.attention.1.value.weight'
input: 'decoder.attention.attention.1.query.weight'
input: 'decoder.attention.attention.1.key.weight'
input: 'decoder.attention.attention.0.value.weight'
input: 'decoder.attention.attention.0.query.weight'
input: 'decoder.attention.attention.0.key.weight'
input: 'decoder.feed_forward.linear_1.bias'
Constant(value=[1.0, 1.0,...) -> decoder.norm_1.weight
Constant(value=[0.0, 0.0,...) -> decoder.norm_1.bias
  LayerNorm[aten_local_function](add, decoder.norm_1.weight, decoder.norm_1.bias) -> norm_1
    __main__.MultiAttentionBlock[aten_local_function](norm_1, mask2, mask, decoder.attention.linear.weight, decoder.attention.attention.1.value.weight, decoder.attention.attention.1.query.weight, decoder.attention.attention.1.key.weight, decoder.attention.attention.0.value.weight, decoder.attention.attention.0.query.weight, decoder.attention.attention.0.key.weight) -> attention
      Add(attention, add) -> add_1
Constant(value=[1.0, 1.0,...) -> decoder.norm_2.weight
Constant(value=[0.0, 0.0,...) -> decoder.norm_2.bias
  LayerNorm[aten_local_function](add_1, decoder.norm_2.weight, decoder.norm_2.bias) -> norm_2
    __main__.FeedForward[aten_local_function](norm_2, decoder.feed_forward.linear_2.weight, decoder.feed_forward.linear_1.weight, decoder.feed_forward.linear_1.bias) -> feed_forward
      Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?

We check again that there are no new discrepancies.

sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.9085307121276855, max=3.940504312515259
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.
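
A minimal sketch with the plot_dot helper imported above, assuming it accepts the ModelProto directly:

# render the ONNX graph with its local functions
plot_dot(onx_module)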

[figure: plot exporter recipes c modules]
<Axes: >

Inlining

The ONNX graph can still be inlined after this.
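
A minimal sketch using the inline_local_functions helper imported at the top; the variable name onx_inlined is ours:

# replace every call to a local function by its body
onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))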

opset: domain='' version=18
opset: domain='aten_local_function' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=dtype('float32') shape=(1024, 16)
init: name='embedding.pe.weight' type=dtype('float32') shape=(1024, 16)
init: name='mask' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.0.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.0.value.weight' type=dtype('float32') shape=(16, 16)
init: name='mask2' type=dtype('float32') shape=(256, 256)
init: name='decoder.attention.attention.1.query.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.key.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.attention.1.value.weight' type=dtype('float32') shape=(16, 16)
init: name='decoder.attention.linear.weight' type=dtype('float32') shape=(16, 32)
init: name='decoder.feed_forward.linear_1.weight' type=dtype('float32') shape=(128, 16)
init: name='decoder.feed_forward.linear_1.bias' type=dtype('float32') shape=(128,)
init: name='decoder.feed_forward.linear_2.weight' type=dtype('float32') shape=(16, 128)
Constant(value=[1]) -> init7_s1_1__11
Gather(embedding.embedding.weight, input_ids) -> embedding__1
Gather(embedding.pe.weight, input_ids) -> pe__1
  Add(embedding__1, pe__1) -> embedding
Constant(value=[1.0, 1.0,...) -> decoder.norm_1.weight__4
Constant(value=[0.0, 0.0,...) -> decoder.norm_1.bias__4
Constant(value=[1.0, 1.0,...) -> decoder.norm_2.weight__4
Constant(value=[0.0, 0.0,...) -> decoder.norm_2.bias__4
Constant(value=[1.0, 1.0,...) -> init1_s16___5
  Mul(init1_s16___5, decoder.norm_1.weight__4) -> LayerNormalizationScalePattern_init1_s16___5
Constant(value=[0.0, 0.0,...) -> init1_s16_2__5
  Mul(decoder.norm_1.weight__4, init1_s16_2__5) -> LayerNormalizationScalePattern_init1_s16_2__5
  Add(LayerNormalizationScalePattern_init1_s16_2__5, decoder.norm_1.bias__4) -> LayerNormalizationScalePattern_init1_s16_3__5
    LayerNormalization(embedding, LayerNormalizationScalePattern_init1_s16___5, LayerNormalizationScalePattern_init1_s16_3__5, axis=-1, epsilon=0.00, stash_type=1) -> norm_1__4
Constant(value=[-0.150363...) -> decoder.attention.linear.bias__6
Constant(value=0.25) -> init1_s___7
Constant(value=[1]) -> init7_s1_1__7
  Reshape(init1_s___7, init7_s1_1__7) -> _onx_reshape0__7
Constant(value=[0]) -> init7_s1_0__7
Constant(value=[30]) -> init7_s1_30__7
  Slice(mask, init7_s1_0__7, init7_s1_30__7, init7_s1_0__7) -> slice_1__7
  Slice(slice_1__7, init7_s1_0__7, init7_s1_30__7, init7_s1_1__7) -> slice_2__7
Constant(value=0.0) -> init1_s_2__7
  Reshape(init1_s_2__7, init7_s1_1__7) -> _onx_reshape02__7
    Equal(slice_2__7, _onx_reshape02__7) -> eq__7_2
Constant(value=[-inf]) -> init1_s1___7
Transpose(decoder.attention.attention.0.query.weight, perm=[1,0]) -> _onx_transpose0__8
  MatMul(norm_1__4, _onx_transpose0__8) -> query__7
Transpose(decoder.attention.attention.0.key.weight, perm=[1,0]) -> _onx_transpose0__9
  MatMul(norm_1__4, _onx_transpose0__9) -> key__7
    Transpose(key__7, perm=[0,2,1]) -> transpose__7_1
    MatMul(query__7, transpose__7_1) -> matmul__7
    Mul(matmul__7, _onx_reshape0__7) -> _onx_mul0__7
  Where(eq__7_2, init1_s1___7, _onx_mul0__7) -> _onx_where0__7
    Softmax(_onx_where0__7, axis=-1) -> softmax__7
Transpose(decoder.attention.attention.0.value.weight, perm=[1,0]) -> _onx_transpose0__10
  MatMul(norm_1__4, _onx_transpose0__10) -> value__7
    MatMul(softmax__7, value__7) -> attention_0__6
Constant(value=0.25) -> init1_s___11
  Reshape(init1_s___11, init7_s1_1__11) -> _onx_reshape0__11
Constant(value=[0]) -> init7_s1_0__11
Constant(value=[30]) -> init7_s1_30__11
  Slice(mask2, init7_s1_0__11, init7_s1_30__11, init7_s1_0__11) -> slice_1__11
  Slice(slice_1__11, init7_s1_0__11, init7_s1_30__11, init7_s1_1__11) -> slice_2__11
Constant(value=0.0) -> init1_s_2__11
  Reshape(init1_s_2__11, init7_s1_1__11) -> _onx_reshape02__11
    Equal(slice_2__11, _onx_reshape02__11) -> eq__11_4
Constant(value=[-inf]) -> init1_s1___11
Transpose(decoder.attention.attention.1.query.weight, perm=[1,0]) -> _onx_transpose0__12
  MatMul(norm_1__4, _onx_transpose0__12) -> query__11
Transpose(decoder.attention.attention.1.key.weight, perm=[1,0]) -> _onx_transpose0__13
  MatMul(norm_1__4, _onx_transpose0__13) -> key__11
    Transpose(key__11, perm=[0,2,1]) -> transpose__11_3
    MatMul(query__11, transpose__11_3) -> matmul__11
    Mul(matmul__11, _onx_reshape0__11) -> _onx_mul0__11
  Where(eq__11_4, init1_s1___11, _onx_mul0__11) -> _onx_where0__11
    Softmax(_onx_where0__11, axis=-1) -> softmax__11
Transpose(decoder.attention.attention.1.value.weight, perm=[1,0]) -> _onx_transpose0__14
  MatMul(norm_1__4, _onx_transpose0__14) -> value__11
    MatMul(softmax__11, value__11) -> attention_1__6
      Concat(attention_0__6, attention_1__6, axis=-1) -> cat__6_0
Transpose(decoder.attention.linear.weight, perm=[1,0]) -> _onx_transpose0__15
  MatMul(cat__6_0, _onx_transpose0__15) -> _onx_matmul0__15
  Add(_onx_matmul0__15, decoder.attention.linear.bias__6) -> attention__4
    Add(attention__4, embedding) -> add_1__4
Constant(value=[1.0, 1.0,...) -> init1_s16___16
  Mul(init1_s16___16, decoder.norm_2.weight__4) -> LayerNormalizationScalePattern_init1_s16___16
Constant(value=[0.0, 0.0,...) -> init1_s16_2__16
  Mul(decoder.norm_2.weight__4, init1_s16_2__16) -> LayerNormalizationScalePattern_init1_s16_2__16
  Add(LayerNormalizationScalePattern_init1_s16_2__16, decoder.norm_2.bias__4) -> LayerNormalizationScalePattern_init1_s16_3__16
    LayerNormalization(add_1__4, LayerNormalizationScalePattern_init1_s16___16, LayerNormalizationScalePattern_init1_s16_3__16, axis=-1, epsilon=0.00, stash_type=1) -> norm_2__4
Constant(value=[0.0262538...) -> decoder.feed_forward.linear_2.bias__17
Transpose(decoder.feed_forward.linear_1.weight, perm=[1,0]) -> _onx_transpose0__18
  MatMul(norm_2__4, _onx_transpose0__18) -> _onx_matmul0__18
    Add(_onx_matmul0__18, decoder.feed_forward.linear_1.bias) -> linear_1__17
      Relu(linear_1__17) -> relu__17
Transpose(decoder.feed_forward.linear_2.weight, perm=[1,0]) -> _onx_transpose0__20
  MatMul(relu__17, _onx_transpose0__20) -> _onx_matmul0__20
  Add(_onx_matmul0__20, decoder.feed_forward.linear_2.bias__17) -> feed_forward__4
    Add(feed_forward__4, add_1__4) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
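
As a sanity check, the inlined model can be executed like the previous ones (a sketch, assuming the onx_inlined model from the sketch above):

sess = InferenceSession(onx_inlined.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, dict(input_ids=input_ids.numpy()))[0]
print(f"max discrepancy={max_diff(y, got)['abs']}")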

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2
    ),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder.optimize] start with 75 nodes
[GraphBuilder.optimize] #patterns=51
[GraphBuilder.remove_unused] 4/46remove_initializer:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 5/46remove_initializer:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 6/46remove_initializer:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 7/46remove_initializer:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 8/46remove_initializer:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 9/46remove_initializer:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 10/46remove_initializer:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder.remove_unused] 14/46remove_initializer:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder.remove_unused] 16/46remove_initializer:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder.remove_unused] 18/46remove_initializer:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 19/46remove_initializer:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 23/46remove_initializer:init1_s_:float32[()]
[GraphBuilder.remove_unused] 24/46remove_initializer:init7_s1_1:int64[(1,)]
[GraphBuilder.remove_unused] 25/46remove_initializer:init7_s1_0:int64[(1,)]
[GraphBuilder.remove_unused] 26/46remove_initializer:init7_s1_30:int64[(1,)]
[GraphBuilder.remove_unused] 27/46remove_initializer:init1_s_2:float32[()]
[GraphBuilder.remove_unused] 33/46remove_initializer:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder.remove_unused] 40/46remove_initializer:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization.optimize] start with 53 nodes, 28 initializers, 51 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization.optimize] use pattern   1/51 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   2/51 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   3/51 - P0 - CastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   4/51 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   5/51 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   6/51 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   7/51 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   8/51 - P0 - GeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   9/51 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  10/51 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  11/51 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  12/51 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  13/51 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  14/51 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  15/51 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  16/51 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  17/51 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  18/51 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  19/51 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  20/51 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  21/51 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  22/51 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  23/51 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  24/51 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  25/51 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  26/51 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  27/51 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  28/51 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  29/51 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  30/51 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  31/51 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern  32/51 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  33/51 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  34/51 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  35/51 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  36/51 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern  37/51 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  38/51 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  39/51 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  40/51 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  41/51 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  42/51 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  43/51 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  44/51 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  45/51 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  46/51 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  47/51 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  48/51 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  49/51 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  50/51 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  51/51 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*CastPattern - time=0.014 | max_time=SoftmaxCrossEntropyLossCastPattern:0.004
[GraphBuilderPatternOptimization.optimize] iteration 1: 51 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] increase priority to 1
[GraphBuilderPatternOptimization.optimize] iteration 2: 51 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.009 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 3: 39 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationScalePattern - time=0.008 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 4: 41 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] increase priority to 2
[GraphBuilderPatternOptimization.optimize] iteration 5: 41 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.006 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 6: 37 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] increase priority to 3
[GraphBuilderPatternOptimization.optimize] iteration 7: 37 nodes, priority=3
[GraphBuilderPatternOptimization.optimize] done after 8 iterations with 37 nodes in 0.076
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002351299990550615
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0014661570021416992
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.001612208998267306
    STAT apply_LayerNormalizationScalePattern +8 -6 #it=1 maxmatch=1 i=2 - time=0.001656857999478234
    STAT build_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0034324219996051397
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00028798499988624826
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.001610279992746655
    STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.000420615997427376
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005751800017606001
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005249440000625327
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004153130066697486
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0012884470052085817
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005035309950471856
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0010113780044775922
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0005228279951552395
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0007516890036640689
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00047723000170663
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002968260014313273
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000539689001016086
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004626949958037585
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00037213200630503707
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044710899601341225
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0003920269991795067
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0008664269953442272
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.403199950000271e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.986199878156185e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.004822103001060896
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0055566909941262566
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=2.0358995243441314e-05
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031650299933971837
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.008763080000790069
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.000779076995968353
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0005384139985835645
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.005380797992984299
    STAT match_MatMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007936110050650313
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0013552689924836159
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008081600026343949
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004969670007994864
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004736139999295119
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033958600397454575
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.001289054001972545
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000834900994959753
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0005168720017536543
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007207430062408093
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004518020068644546
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005238180019659922
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0017317149977316149
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006951350005692802
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043415199252194725
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.007078238002577564
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004059839993715286
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006120830003055744
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.001375295003526844
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008386039953620639
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000900467002793448
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0006130619985924568
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0006817969988333061
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007598710035381373
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0005397129934863187
    STAT remove_identity_nodes +2 -4 #it=3 maxmatch=0 i=0 - time=0.0008773940026003402
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--
     INPUT:   1 x 7t
    OUTPUT:   1 x 1t
      INIT:  29 x 1t
      INIT:   1 x 7t
      NODE:   8 x Add
      NODE:   1 x Concat
      NODE:   2 x Equal
      NODE:   2 x Gather
      NODE:   2 x LayerNormalization
      NODE:  11 x MatMul
      NODE:   4 x Mul
      NODE:   1 x Relu
      NODE:   2 x Softmax
      NODE:   2 x Where
      NODE:   2 x com.microsoft.FusedMatMul
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   8 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   7 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      INIT:   1 x 7t[1]
      NODE:   2 x Add -SIG- 1t[16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   3 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   2 x LayerNormalization -SIG- 1t[1x30x16], 1t[16], 1t[16]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   4 x Mul -SIG- 1t[16], 1t[16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
[GraphBuilder.remove_unused] 9/30remove_initializer:init7_s1_-1:int64[(1,)]
[GraphBuilder.remove_unused] 10/30remove_initializer:init1_s1_:float32[(1,)]
[GraphBuilder.remove_unused] 11/30remove_initializer:init1_s1_2:float32[(1,)]
[GraphBuilder.remove_unused] 16/30remove_initializer:_onx_reshape0:float32[(1,)]
[GraphBuilder.remove_unused] 22/30remove_initializer:_onx_reshape03:float32[(1,)]
[GraphBuilder.remove_unused] 2/31remove_initializer:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 3/31remove_initializer:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 5/31remove_initializer:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 6/31remove_initializer:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 23/31remove_initializer:init1_s16_:float32[(16,)]
[GraphBuilder.remove_unused] 24/31remove_initializer:init1_s16_2:float32[(16,)]
[GraphBuilder.remove_unused] 26/31remove_initializer:LayerNormalizationScalePattern_init1_s16_2:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 29/31remove_initializer:LayerNormalizationScalePattern_init1_s16_5:torch.float32[torch.Size([16])]
[GraphBuilder.optimize] done with 31 nodes in 0.094
opset: domain='' version=18
opset: domain='com.microsoft' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='p_embedding_embedding_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_embedding_pe_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_decoder_attention_linear_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_feed_forward_linear_1_bias' type=dtype('float32') shape=(128,)
init: name='p_decoder_feed_forward_linear_2_bias' type=dtype('float32') shape=(16,)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-inf], dtype=float32)
init: name='_onx_transpose0' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose02' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose03' type=dtype('float32') shape=(16, 16)
init: name='slice_2' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape02' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose04' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose05' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose06' type=dtype('float32') shape=(16, 16)
init: name='slice_4' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape04' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose07' type=dtype('float32') shape=(32, 16)
init: name='_onx_transpose08' type=dtype('float32') shape=(16, 128)
init: name='_onx_transpose09' type=dtype('float32') shape=(128, 16)
init: name='LayerNormalizationScalePattern_init1_s16_' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_3' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_4' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_6' type=dtype('float32') shape=(16,)
Equal(slice_2, _onx_reshape02) -> eq
Gather(p_embedding_embedding_weight, input_ids) -> embedding
Gather(p_embedding_pe_weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add02
      MatMul(_onx_add02, _onx_transpose0) -> linear
MatMul(_onx_add02, _onx_transpose02) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul02
  Where(eq, init1_s1_3, _onx_mul02) -> _onx_where0
    Softmax(_onx_where0, axis=-1) -> softmax
MatMul(_onx_add02, _onx_transpose03) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_add02, _onx_transpose04) -> linear_3
MatMul(_onx_add02, _onx_transpose05) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul03
MatMul(_onx_add02, _onx_transpose06) -> linear_5
Equal(slice_4, _onx_reshape04) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul03) -> _onx_where02
    Softmax(_onx_where02, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, _onx_transpose07) -> _onx_matmul0
        Add(_onx_matmul0, p_decoder_attention_linear_bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, LayerNormalizationScalePattern_init1_s16_4, LayerNormalizationScalePattern_init1_s16_6, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add04
        MatMul(_onx_add04, _onx_transpose08) -> _onx_matmul02
          Add(_onx_matmul02, p_decoder_feed_forward_linear_1_bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, _onx_transpose09) -> _onx_matmul03
                Add(_onx_matmul03, p_decoder_feed_forward_linear_2_bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows the kernel FusedMatMul[com.microsoft], which implements the equivalent of Gemm but works on tensors of any rank, not only 2D. How does the optimizer behave on the model that exports the modules as local functions? It optimizes every local function independently.
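As a quick sanity check on the semantics (a minimal sketch, not part of the exporter), FusedMatMul(A, B, alpha=0.25, transB=1) computes alpha * A @ B^T over the last two dimensions, which is exactly the scaled product query @ key.transpose(-2, -1) * C**-0.5 computed in AttentionBlock with C = 16:

import torch

# Reproduce the fused kernel's formula in eager mode and compare it
# with the original attention computation (16**-0.5 == 0.25).
query = torch.randn(1, 30, 16)
key = torch.randn(1, 30, 16)
fused = 0.25 * (query @ key.transpose(-2, -1))    # FusedMatMul semantics
eager = query @ key.transpose(-2, -1) * 16**-0.5  # original PyTorch code
assert torch.allclose(fused, eager)

We reduce the verbosity…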

onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=dtype('float32') shape=(1024, 16)
init: name='embedding.pe.weight' type=dtype('float32') shape=(1024, 16)
init: name='_onx_transpose0' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose02' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose03' type=dtype('float32') shape=(16, 16)
init: name='slice_2' type=dtype('float32') shape=(30, 30)
init: name='_onx_transpose04' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose022' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose032' type=dtype('float32') shape=(16, 16)
init: name='slice_4' type=dtype('float32') shape=(30, 30)
init: name='_onx_transpose05' type=dtype('float32') shape=(32, 16)
init: name='decoder.feed_forward.linear_1.bias' type=dtype('float32') shape=(128,)
init: name='_onx_transpose06' type=dtype('float32') shape=(16, 128)
init: name='_onx_transpose023' type=dtype('float32') shape=(128, 16)
__main__.Embedding[aten_local_function](input_ids, embedding.pe.weight, embedding.embedding.weight) -> embedding
  __main__.DecoderLayer[aten_local_function](embedding, _onx_transpose06, _onx_transpose023, slice_4, slice_2, _onx_transpose05, _onx_transpose04, _onx_transpose032, _onx_transpose03, _onx_transpose022, _onx_transpose02, _onx_transpose0, decoder.feed_forward.linear_1.bias) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
----- function name=Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'input_ids'
input: 'weight'
Gather(weight, input_ids) -> output
output: name='output' type=? shape=?
----- function name=__main__.Embedding domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'input_ids'
input: 'embedding.pe.weight'
input: 'embedding.embedding.weight'
Embedding[aten_local_function](input_ids, embedding.embedding.weight) -> embedding
Embedding[aten_local_function](input_ids, embedding.pe.weight) -> pe
  Add(embedding, pe) -> output
output: name='output' type=? shape=?
----- function name=LayerNorm domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'add'
Constant(value=[1.0, 1.0,...) -> LayerNormalizationScalePattern_init1_s16_
Constant(value=[0.0, 0.0,...) -> LayerNormalizationScalePattern_init1_s16_3
  LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> output
output: name='output' type=? shape=?
----- function name=Linear domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: 'layer_norm'
input: '_onx_transpose0'
MatMul(layer_norm, _onx_transpose0) -> output
output: name='output' type=? shape=?
----- function name=__main__.AttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_2'
input: '_onx_transpose03'
input: '_onx_transpose02'
input: '_onx_transpose0'
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.0]) -> _onx_reshape02
  Equal(slice_2, _onx_reshape02) -> eq
Linear[aten_local_function](layer_norm, _onx_transpose0) -> query
Linear[aten_local_function](layer_norm, _onx_transpose02) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul0
  Where(eq, init1_s1_, _onx_mul0) -> _onx_where0
    Softmax(_onx_where0, axis=-1) -> softmax
Linear[aten_local_function](layer_norm, _onx_transpose03) -> value
  MatMul(softmax, value) -> output
output: name='output' type=? shape=?
----- function name=Linear_2 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'cat'
input: '_onx_transpose0'
input: 'bias'
MatMul(cat, _onx_transpose0) -> _onx_matmul0
  Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.MultiAttentionBlock domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm'
input: 'slice_4'
input: 'slice_2'
input: '_onx_transpose05'
input: '_onx_transpose04'
input: '_onx_transpose032'
input: '_onx_transpose03'
input: '_onx_transpose022'
input: '_onx_transpose02'
input: '_onx_transpose0'
Constant(value=[-0.150363...) -> decoder.attention.linear.bias
__main__.AttentionBlock[aten_local_function](layer_norm, slice_2, _onx_transpose03, _onx_transpose02, _onx_transpose0) -> attention_0
__main__.AttentionBlock[aten_local_function](layer_norm, slice_4, _onx_transpose032, _onx_transpose022, _onx_transpose04) -> attention_1
  Concat(attention_0, attention_1, axis=-1) -> cat
  Linear_2[aten_local_function](cat, _onx_transpose05, decoder.attention.linear.bias) -> output
output: name='output' type=? shape=?
----- function name=Linear_3 domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: '_onx_transpose0'
input: 'bias'
MatMul(layer_norm_1, _onx_transpose0) -> _onx_matmul0
  Add(_onx_matmul0, bias) -> output
output: name='output' type=? shape=?
----- function name=ReLU domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'linear_7'
Relu(linear_7) -> output
output: name='output' type=? shape=?
----- function name=__main__.FeedForward domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'layer_norm_1'
input: '_onx_transpose02'
input: '_onx_transpose0'
input: 'decoder.feed_forward.linear_1.bias'
Constant(value=[0.0262538...) -> decoder.feed_forward.linear_2.bias
Linear_3[aten_local_function](layer_norm_1, _onx_transpose0, decoder.feed_forward.linear_1.bias) -> linear_1
  ReLU[aten_local_function](linear_1) -> relu
  Linear_3[aten_local_function](relu, _onx_transpose02, decoder.feed_forward.linear_2.bias) -> output
output: name='output' type=? shape=?
----- function name=__main__.DecoderLayer domain=aten_local_function
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: 'add'
input: '_onx_transpose06'
input: '_onx_transpose023'
input: 'slice_4'
input: 'slice_2'
input: '_onx_transpose05'
input: '_onx_transpose04'
input: '_onx_transpose032'
input: '_onx_transpose03'
input: '_onx_transpose022'
input: '_onx_transpose02'
input: '_onx_transpose0'
input: 'decoder.feed_forward.linear_1.bias'
LayerNorm[aten_local_function](add) -> norm_1
  __main__.MultiAttentionBlock[aten_local_function](norm_1, slice_4, slice_2, _onx_transpose05, _onx_transpose04, _onx_transpose032, _onx_transpose03, _onx_transpose022, _onx_transpose02, _onx_transpose0) -> attention
    Add(attention, add) -> add_1
      LayerNorm[aten_local_function](add_1) -> norm_2
        __main__.FeedForward[aten_local_function](norm_2, _onx_transpose023, _onx_transpose06, decoder.feed_forward.linear_1.bias) -> feed_forward
      Add(feed_forward, add_1) -> output
output: name='output' type=? shape=?

It seems to work on this simple case as well, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
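If a flat graph is needed afterwards, the local functions can be inlined back into the main graph. A minimal sketch using onnx.inliner (the printed count is an assumption about the fully inlined result):

from onnx.inliner import inline_local_functions

# Expand every aten_local_function call into the main graph; the result
# should be equivalent to the model exported without
# export_modules_as_functions=True.
onx_inlined = inline_local_functions(onx_module_optimized)
print(len(onx_inlined.functions))  # 0 once all local functions are inlined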

Optimizations for CUDA

The optimizer may behave differently knowing the model is running on CUDA. It may use different kernels and perform different optimizations if needed. That may not be the best place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more appropriate for some kernels. This could even be decided at runtime.

onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder.optimize] start with 75 nodes
[GraphBuilder.optimize] #patterns=51
[GraphBuilder.remove_unused] 4/46remove_initializer:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 5/46remove_initializer:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 6/46remove_initializer:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 7/46remove_initializer:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 8/46remove_initializer:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 9/46remove_initializer:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder.remove_unused] 10/46remove_initializer:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder.remove_unused] 14/46remove_initializer:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder.remove_unused] 16/46remove_initializer:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder.remove_unused] 18/46remove_initializer:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 19/46remove_initializer:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder.remove_unused] 23/46remove_initializer:init1_s_:float32[()]
[GraphBuilder.remove_unused] 24/46remove_initializer:init7_s1_1:int64[(1,)]
[GraphBuilder.remove_unused] 25/46remove_initializer:init7_s1_0:int64[(1,)]
[GraphBuilder.remove_unused] 26/46remove_initializer:init7_s1_30:int64[(1,)]
[GraphBuilder.remove_unused] 27/46remove_initializer:init1_s_2:float32[()]
[GraphBuilder.remove_unused] 33/46remove_initializer:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder.remove_unused] 40/46remove_initializer:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization.optimize] start with 53 nodes, 28 initializers, 51 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization.optimize] use pattern   1/51 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   2/51 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   3/51 - P0 - CastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   4/51 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   5/51 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   6/51 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   7/51 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   8/51 - P0 - GeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern   9/51 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  10/51 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  11/51 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  12/51 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  13/51 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  14/51 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  15/51 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  16/51 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  17/51 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  18/51 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  19/51 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  20/51 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  21/51 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  22/51 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  23/51 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  24/51 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  25/51 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  26/51 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  27/51 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  28/51 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  29/51 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  30/51 - P1 - MatMulAddPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  31/51 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern  32/51 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  33/51 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  34/51 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  35/51 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  36/51 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization.optimize] use pattern  37/51 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  38/51 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  39/51 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  40/51 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  41/51 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  42/51 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  43/51 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  44/51 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  45/51 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  46/51 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  47/51 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  48/51 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  49/51 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization.optimize] use pattern  50/51 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization.optimize] use pattern  51/51 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*CastPattern - time=0.013 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilderPatternOptimization.optimize] iteration 1: 51 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] increase priority to 1
[GraphBuilderPatternOptimization.optimize] iteration 2: 51 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 3: 39 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*LayerNormalizationScalePattern - time=0.004 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 4: 41 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] increase priority to 2
[GraphBuilderPatternOptimization.optimize] iteration 5: 41 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.007 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 6: 37 nodes, priority=2
[GraphBuilderPatternOptimization.optimize] increase priority to 3
[GraphBuilderPatternOptimization.optimize] iteration 7: 37 nodes, priority=3
[GraphBuilderPatternOptimization.optimize] done after 8 iterations with 37 nodes in 0.059
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00033199400058947504
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.00044636500024353154
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004514319989539217
    STAT apply_LayerNormalizationScalePattern +8 -6 #it=1 maxmatch=1 i=2 - time=0.0007248299989441875
    STAT build_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0022589520012843423
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00031031700200401247
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0009253140015061945
    STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.00036864899811916985
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0004700450044765603
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.000425509009801317
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002755290006462019
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000751970994315343
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00033706999602145515
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0007624669997312594
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00047389600513270125
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000455398992926348
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00040974299918161705
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021438399926410057
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023231499289977364
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00040036199789028615
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00022855099814478308
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030827800583210774
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0003567460007616319
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0007394730018859264
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.263699858915061e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.395899981725961e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0036905640008626506
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.004693301001680084
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=1.3928998669143766e-05
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035539600139600225
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.007192787001258694
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00045746899559162557
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.000513541996042477
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.004610444000718417
    STAT match_MatMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006537789995491039
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0011999769958492834
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000663203994918149
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005202160064072814
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004569890006678179
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030697799593326636
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0012463210005080327
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006260909940465353
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0006521489995066077
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006813899999542627
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004642429994419217
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005402189999585971
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0013755860018136445
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000347028995747678
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003770200055441819
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.006253109993849648
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026055299895233475
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003796259989030659
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0010040179986390285
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009082119977392722
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006176680035423487
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0005275410003378056
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004835229992750101
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004754610054078512
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00039690500125288963
    STAT remove_identity_nodes +2 -4 #it=3 maxmatch=0 i=0 - time=0.0007931020008982159
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--
     INPUT:   1 x 7t
    OUTPUT:   1 x 1t
      INIT:  29 x 1t
      INIT:   1 x 7t
      NODE:   8 x Add
      NODE:   1 x Concat
      NODE:   2 x Equal
      NODE:   2 x Gather
      NODE:   2 x LayerNormalization
      NODE:  11 x MatMul
      NODE:   4 x Mul
      NODE:   1 x Relu
      NODE:   2 x Softmax
      NODE:   2 x Where
      NODE:   2 x com.microsoft.FusedMatMul
--MODEL: 37 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   8 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   7 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      INIT:   1 x 7t[1]
      NODE:   2 x Add -SIG- 1t[16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   3 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   2 x LayerNormalization -SIG- 1t[1x30x16], 1t[16], 1t[16]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   4 x Mul -SIG- 1t[16], 1t[16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
[GraphBuilder.remove_unused] 9/30remove_initializer:init7_s1_-1:int64[(1,)]
[GraphBuilder.remove_unused] 10/30remove_initializer:init1_s1_:float32[(1,)]
[GraphBuilder.remove_unused] 11/30remove_initializer:init1_s1_2:float32[(1,)]
[GraphBuilder.remove_unused] 16/30remove_initializer:_onx_reshape0:float32[(1,)]
[GraphBuilder.remove_unused] 22/30remove_initializer:_onx_reshape03:float32[(1,)]
[GraphBuilder.remove_unused] 2/31remove_initializer:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 3/31remove_initializer:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 5/31remove_initializer:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 6/31remove_initializer:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 23/31remove_initializer:init1_s16_:float32[(16,)]
[GraphBuilder.remove_unused] 24/31remove_initializer:init1_s16_2:float32[(16,)]
[GraphBuilder.remove_unused] 26/31remove_initializer:LayerNormalizationScalePattern_init1_s16_2:torch.float32[torch.Size([16])]
[GraphBuilder.remove_unused] 29/31remove_initializer:LayerNormalizationScalePattern_init1_s16_5:torch.float32[torch.Size([16])]
[GraphBuilder.optimize] done with 31 nodes in 0.074
opset: domain='' version=18
opset: domain='com.microsoft' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='p_embedding_embedding_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_embedding_pe_weight' type=dtype('float32') shape=(1024, 16)
init: name='p_decoder_attention_linear_bias' type=dtype('float32') shape=(16,)
init: name='p_decoder_feed_forward_linear_1_bias' type=dtype('float32') shape=(128,)
init: name='p_decoder_feed_forward_linear_2_bias' type=dtype('float32') shape=(16,)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-inf], dtype=float32)
init: name='_onx_transpose0' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose02' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose03' type=dtype('float32') shape=(16, 16)
init: name='slice_2' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape02' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose04' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose05' type=dtype('float32') shape=(16, 16)
init: name='_onx_transpose06' type=dtype('float32') shape=(16, 16)
init: name='slice_4' type=dtype('float32') shape=(30, 30)
init: name='_onx_reshape04' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='_onx_transpose07' type=dtype('float32') shape=(32, 16)
init: name='_onx_transpose08' type=dtype('float32') shape=(16, 128)
init: name='_onx_transpose09' type=dtype('float32') shape=(128, 16)
init: name='LayerNormalizationScalePattern_init1_s16_' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_3' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_4' type=dtype('float32') shape=(16,)
init: name='LayerNormalizationScalePattern_init1_s16_6' type=dtype('float32') shape=(16,)
Equal(slice_2, _onx_reshape02) -> eq
Gather(p_embedding_embedding_weight, input_ids) -> embedding
Gather(p_embedding_pe_weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, LayerNormalizationScalePattern_init1_s16_, LayerNormalizationScalePattern_init1_s16_3, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add02
      MatMul(_onx_add02, _onx_transpose0) -> linear
MatMul(_onx_add02, _onx_transpose02) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul02
  Where(eq, init1_s1_3, _onx_mul02) -> _onx_where0
    Softmax(_onx_where0, axis=-1) -> softmax
MatMul(_onx_add02, _onx_transpose03) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_add02, _onx_transpose04) -> linear_3
MatMul(_onx_add02, _onx_transpose05) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul03
MatMul(_onx_add02, _onx_transpose06) -> linear_5
Equal(slice_4, _onx_reshape04) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul03) -> _onx_where02
    Softmax(_onx_where02, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, _onx_transpose07) -> _onx_matmul0
        Add(_onx_matmul0, p_decoder_attention_linear_bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, LayerNormalizationScalePattern_init1_s16_4, LayerNormalizationScalePattern_init1_s16_6, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add04
        MatMul(_onx_add04, _onx_transpose08) -> _onx_matmul02
          Add(_onx_matmul02, p_decoder_feed_forward_linear_1_bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, _onx_transpose09) -> _onx_matmul03
                Add(_onx_matmul03, p_decoder_feed_forward_linear_2_bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Comparing the optimized and non-optimized models

The following tool tries to match the nodes and shape inference results from two models. If they are not too different, the function is able to spot the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.

res1, res2, align, dc = compare_onnx_execution(onx, onx_optimized, verbose=1)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 88 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 88 results (first model)
[compare_onnx_execution] got 56 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 90 pairs
[compare_onnx_execution] done
------------
001 = | INITIA float32  2:1024x16            BORW                 p_ | INITIA float32  2:1024x16            BORW                 p_
002 = | INITIA float32  2:1024x16            UUVW                 p_ | INITIA float32  2:1024x16            UUVW                 p_
003 ~ | INITIA float32  1:16                 EEEE                 p_ | INITIA float32  1:16                 AAAA                 p_
004 ~ | INITIA float32  1:16                 AAAA                 p_ | INITIA float32  1:128                AAAA                 p_
005 ~ | INITIA float32  2:16x16              AAAA                 p_ | INITIA float32  1:16                 AAAA                 p_
006 ~ | INITIA float32  2:16x16              AAAA                 p_ | INITIA float32  1:1                  ?AAA                 in
007 - | INITIA float32  2:16x16              YAAC                 p_ |
008 ~ | INITIA float32  2:16x16              AAZB                 p_ | INITIA float32  2:16x16              CAZA                 _o
009 ~ | INITIA float32  2:16x16              AABA                 p_ | INITIA float32  2:16x16              AAAA                 _o
010 ~ | INITIA float32  2:16x16              ZZAA                 p_ | INITIA float32  2:16x16              AAAA                 _o
011 ~ | INITIA float32  2:16x32              AAYA                 p_ | INITIA float32  2:30x30              KGSP                 sl
012 ~ | INITIA float32  1:16                 AAAA                 p_ | INITIA float32  1:1                  AAAA                 _o
013 ~ | INITIA float32  1:16                 EEEE                 p_ | INITIA float32  2:16x16              ZAAA                 _o
014 ~ | INITIA float32  1:16                 AAAA                 p_ | INITIA float32  2:16x16              AAAA                 _o
015 - | INITIA float32  2:128x16             BAAX                 p_ |
016 ~ | INITIA float32  1:128                AAAA                 p_ | INITIA float32  2:16x16              AAZY                 _o
017 - | INITIA float32  2:16x128             AZAA                 p_ |
018 ~ | INITIA float32  1:16                 AAAA                 p_ | INITIA float32  2:30x30              KGSP                 sl
019 - | INITIA float32  2:256x256            AOCQ                 b_ |
020 - | INITIA float32  2:256x256            AOCQ                 b_ |
021 - | INITIA float32                       AAAA                 in |
022 ~ | INITIA int64    1:1                  BAAA                 in | INITIA float32  1:1                  AAAA                 _o
023 ~ | INITIA int64    1:1                  AAAA                 in | INITIA float32  2:32x16              BAAA                 _o
024 + |                                                              | INITIA float32  2:16x128             AVEA                 _o
025 + |                                                              | INITIA float32  2:128x16             BZAA                 _o
026 ~ | INITIA int64    1:1                  EAAA                 in | INITIA float32  1:16                 EEEE                 La
027 - | INITIA float32                       AAAA                 in |
028 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  1:16                 AAAA                 La
029 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 La
030 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 La
031 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
032 = | RESULT float32  3:1x30x16            FDNV Gather          em | RESULT float32  3:1x30x16            FDNV Gather          em
033 = | RESULT float32  3:1x30x16            QUDH Gather          em | RESULT float32  3:1x30x16            QUDH Gather          em
034 = | RESULT float32  3:1x30x16            WYQC Add             ad | RESULT float32  3:1x30x16            WYQC Add             ad
035 - | RESULT float32  1:16                 EEEE Mul             La |
036 - | RESULT float32  1:16                 AAAA Mul             La |
037 - | RESULT float32  1:16                 AAAA Add             La |
038 = | RESULT float32  3:1x30x16            ZBBZ LayerNormalizat _o | RESULT float32  3:1x30x16            ZBBZ LayerNormalizat _o
039 - | RESULT float32  2:16x16              CAZA Transpose       _o |
040 = | RESULT float32  3:1x30x16            MQST MatMul          li | RESULT float32  3:1x30x16            MQST MatMul          li
041 - | RESULT float32  2:16x16              AAAA Transpose       _o |
042 = | RESULT float32  3:1x30x16            VXOC MatMul          li | RESULT float32  3:1x30x16            VXOC MatMul          li
043 - | RESULT float32  2:16x16              AAAA Transpose       _o |
044 ~ | RESULT float32  3:1x30x16            VBUE MatMul          li | RESULT float32  3:1x30x30            YCFE FusedMatMul     _o
045 ~ | RESULT float32  3:1x16x30            WAUT Transpose       tr | RESULT float32  3:1x30x16            VBUE MatMul          li
046 - | RESULT float32  3:1x30x30            QLXR MatMul          ma |
047 - | RESULT float32  1:1                  AAAA Reshape         _o |
048 - | RESULT float32  3:1x30x30            YCFE Mul             _o |
049 - | RESULT float32  2:30x256             KGAH Slice           sl |
050 - | RESULT float32  2:30x30              KGSP Slice           sl |
051 - | RESULT float32  1:1                  AAAA Reshape         _o |
052 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
053 = | RESULT float32  3:1x30x30            ???? Where           _o | RESULT float32  3:1x30x30            ???? Where           _o
054 = | RESULT float32  3:1x30x30            HGHH Softmax         so | RESULT float32  3:1x30x30            HGHH Softmax         so
055 = | RESULT float32  3:1x30x16            DYYY MatMul          ma | RESULT float32  3:1x30x16            DYYY MatMul          ma
056 - | RESULT float32  2:16x16              ZAAA Transpose       _o |
057 = | RESULT float32  3:1x30x16            BLRA MatMul          li | RESULT float32  3:1x30x16            BLRA MatMul          li
058 - | RESULT float32  2:16x16              AAAA Transpose       _o |
059 = | RESULT float32  3:1x30x16            ZQZD MatMul          li | RESULT float32  3:1x30x16            ZQZD MatMul          li
060 - | RESULT float32  2:16x16              AAZY Transpose       _o |
061 - | RESULT float32  3:1x30x16            ZBFC MatMul          li |
062 - | RESULT float32  3:1x16x30            PBAC Transpose       tr |
063 ~ | RESULT float32  3:1x30x30            KHTN MatMul          ma | RESULT float32  3:1x30x30            CWZX FusedMatMul     _o
064 - | RESULT float32  1:1                  AAAA Reshape         _o |
065 ~ | RESULT float32  3:1x30x30            CWZX Mul             _o | RESULT float32  3:1x30x16            ZBFC MatMul          li
066 - | RESULT float32  2:30x256             KGAH Slice           sl |
067 - | RESULT float32  2:30x30              KGSP Slice           sl |
068 - | RESULT float32  1:1                  AAAA Reshape         _o |
069 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
070 = | RESULT float32  3:1x30x30            ???? Where           _o | RESULT float32  3:1x30x30            ???? Where           _o
071 = | RESULT float32  3:1x30x30            HHHH Softmax         so | RESULT float32  3:1x30x30            HHHH Softmax         so
072 = | RESULT float32  3:1x30x16            SCAB MatMul          ma | RESULT float32  3:1x30x16            SCAB MatMul          ma
073 = | RESULT float32  3:1x30x32            VAZA Concat          ca | RESULT float32  3:1x30x32            VAZA Concat          ca
074 - | RESULT float32  2:32x16              BAAA Transpose       _o |
075 = | RESULT float32  3:1x30x16            WAAA MatMul          _o | RESULT float32  3:1x30x16            WAAA MatMul          _o
076 = | RESULT float32  3:1x30x16            TYYY Add             li | RESULT float32  3:1x30x16            TYYY Add             li
077 = | RESULT float32  3:1x30x16            OVNA Add             ad | RESULT float32  3:1x30x16            OVNA Add             ad
078 - | RESULT float32  1:16                 EEEE Mul             La |
079 - | RESULT float32  1:16                 AAAA Mul             La |
080 - | RESULT float32  1:16                 AAAA Add             La |
081 = | RESULT float32  3:1x30x16            ZBBZ LayerNormalizat _o | RESULT float32  3:1x30x16            ZBBZ LayerNormalizat _o
082 - | RESULT float32  2:16x128             AVEA Transpose       _o |
083 = | RESULT float32  3:1x30x128           EMOF MatMul          _o | RESULT float32  3:1x30x128           EMOF MatMul          _o
084 = | RESULT float32  3:1x30x128           PXZQ Add             li | RESULT float32  3:1x30x128           PXZQ Add             li
085 = | RESULT float32  3:1x30x128           GZAU Relu            re | RESULT float32  3:1x30x128           GZAU Relu            re
086 - | RESULT float32  2:128x16             BZAA Transpose       _o |
087 = | RESULT float32  3:1x30x16            AZBC MatMul          _o | RESULT float32  3:1x30x16            AZBC MatMul          _o
088 = | RESULT float32  3:1x30x16            AZBC Add             li | RESULT float32  3:1x30x16            AZBC Add             li
089 = | RESULT float32  3:1x30x16            PUPD Add             ou | RESULT float32  3:1x30x16            PUPD Add             ou
090 = | OUTPUT float32  3:1x30x16            PUPD                 ou | OUTPUT float32  3:1x30x16            PUPD                 ou
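
Beyond this alignment, we can also check numerically that both models produce the same outputs with onnxruntime. A short sketch (the call to max_diff and its output format are assumptions based on the helpers imported above):

from onnxruntime import InferenceSession

# Run both models on the same input and measure the discrepancy.
sess1 = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
sess2 = InferenceSession(onx_optimized.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
expected = sess1.run(None, feeds)[0]
got = sess2.run(None, feeds)[0]
print(max_diff(expected, got))  # the difference should be tiny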

The conversion should also handle dynamic shapes, since the input sequence can be of any length. A sketch of what the call could look like follows, but a full treatment is a topic for another example.
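This is only a sketch (not executed here), assuming to_onnx accepts a torch.export-style dynamic_shapes specification; the dimension name seq is ours:

import torch

# Mark the sequence dimension (axis 1 of input_ids) as dynamic, bounded
# by the context size used to build the attention mask. The key of the
# dynamic_shapes dictionary must match the forward argument name.
seq = torch.export.Dim("seq", max=256)
onx_dynamic = to_onnx(
    llm,
    (input_ids,),
    dynamic_shapes={"input_ids": {1: seq}},
)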

Total running time of the script: (0 minutes 6.291 seconds)