to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with PyTorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    """Sum of two learned lookup tables applied to the same token ids.

    Both ``embedding`` and ``pe`` are ``torch.nn.Embedding(vocab_size,
    embedding_dim)`` tables and ``forward`` returns ``embedding(x) + pe(x)``.

    NOTE(review): despite its name, ``pe`` is indexed with the token ids ``x``
    and not with token positions, so it behaves like a second token embedding
    rather than a positional embedding. It is kept as is because the exported
    graphs printed in this tutorial rely on this structure.

    Args:
        vocab_size: number of rows of both tables.
        embedding_dim: size of each embedding vector.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        # The attribute names end up in the parameter names
        # (``embedding.embedding.weight``, ``embedding.pe.weight``).
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        # Same indices for both lookups, then an element-wise sum.
        return self.embedding(x) + self.pe(x)


class AttentionBlock(torch.nn.Module):
    """Single causal self-attention head.

    The input is projected into queries, keys and values by three bias-free
    ``embedding_dim -> embedding_dim`` linear layers. The query/key scores are
    scaled by ``C ** -0.5`` (``C`` being the last input dimension). Positions
    above the diagonal of a lower-triangular mask are set to ``-inf`` before
    the softmax, and the softmax weights are applied to the values.

    Args:
        embedding_dim: size of the last dimension of the input and output.
        context_size: side of the square lower-triangular ``mask`` buffer; the
            mask is sliced to the input sequence length, so longer inputs
            would not match it.
    """

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        # Lower-triangular matrix of ones: entry (i, j) is 1 when j <= i.
        causal = torch.tril(torch.ones(context_size, context_size, dtype=torch.float))
        self.register_buffer("mask", causal)

    def forward(self, x):
        # Unpacking enforces a rank-3 input (batch, sequence, channels).
        _, seq_len, channels = x.size()

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled dot-product scores of shape (batch, sequence, sequence).
        scores = q @ k.transpose(-2, -1) * channels**-0.5
        # Zeros of the triangular mask mark future positions to hide.
        hidden = self.mask[:seq_len, :seq_len] == 0
        weights = torch.nn.functional.softmax(
            input=scores.masked_fill(hidden, float("-inf")), dim=-1
        )
        return weights @ v


class MultiAttentionBlock(torch.nn.Module):
    """Runs ``num_heads`` independent ``AttentionBlock`` heads and mixes them.

    Every head receives the full input and returns ``embedding_dim`` features.
    The head outputs are concatenated on the last axis (``embedding_dim *
    num_heads`` features) and projected back to ``embedding_dim`` by a linear
    layer with bias.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        heads = [AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        # "attention" and "linear" are part of the parameter names
        # (decoder.attention.attention.<i>.*, decoder.attention.linear.*).
        self.attention = torch.nn.ModuleList(heads)
        self.linear = torch.nn.Linear(embedding_dim * num_heads, embedding_dim)

    def forward(self, x):
        per_head = [head(x) for head in self.attention]
        return self.linear(torch.cat(per_head, dim=-1))


class FeedForward(torch.nn.Module):
    """Two-layer MLP applied on the last dimension: ``linear_2(relu(linear_1(x)))``.

    The first layer expands ``embedding_dim`` to ``ff_dim``, a ReLU follows,
    and the second layer projects back to ``embedding_dim``. Both linear
    layers have a bias.
    """

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        # Apply the three stages one after the other.
        for stage in (self.linear_1, self.relu, self.linear_2):
            x = stage(x)
        return x


class DecoderLayer(torch.nn.Module):
    """Pre-norm decoder layer with two residual sub-layers.

    ``h = attention(norm_1(x)) + x`` followed by
    ``out = feed_forward(norm_2(h)) + h``. Each sub-layer input is normalized
    by its own LayerNorm over the last ``embedding_dim`` features.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        # Creation order is kept so that random initialization is unchanged.
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(embedding_dim)

    def forward(self, x):
        # Attention sub-layer plus residual connection.
        hidden = self.attention(self.norm_1(x)) + x
        # Feed-forward sub-layer plus residual connection.
        return self.feed_forward(self.norm_2(hidden)) + hidden


class LLM(torch.nn.Module):
    """Tiny language model: ``Embedding`` followed by one ``DecoderLayer``.

    ``forward`` maps integer ids to hidden states whose last dimension is
    ``embedding_dim``. There is no projection back to the vocabulary.

    Args:
        vocab_size: number of rows of both embedding tables.
        embedding_dim: hidden size used everywhere in the model.
        num_heads: number of attention heads in the decoder layer.
        context_size: side of the causal mask of each attention head.
        ff_dim: hidden size of the feed-forward block.
    """

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        return self.decoder(self.embedding(input_ids))


llm = LLM()
# Shape of the dummy batch: (batch size, sequence length).
dim = (1, 30)
# Draw ids from the model's own vocabulary instead of hard-coding 1024, so the
# example keeps working if LLM(vocab_size=...) is changed. randint produces
# int64 directly when asked, so no extra cast is needed.
input_ids = torch.randint(
    0, llm.embedding.embedding.num_embeddings, dim, dtype=torch.int64
)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.608179092407227, max=4.440975189208984

First conversion to ONNX

The conversion relies on torch.export.export(), which gives:

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias, 1e-05, False), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias, 1e-05, False), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

Then the function to_onnx converts it into ONNX.

# Export the model with the dummy input as example arguments; per the text
# above, to_onnx relies on torch.export.export() before building the graph.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
  Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
    Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
      Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
        Softmax(masked_fill, axis=-1) -> softmax
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
        MatMul(softmax, linear_2) -> matmul_1
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
      Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there is no discrepancy.

# Run the exported model with onnxruntime on CPU and compare it with the
# eager PyTorch output ``y``.
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
# run returns a list with one entry per graph output; this model has one.
got = sess.run(None, feeds)[0]

# max_diff returns a dictionary; 'abs' is presumably the maximum absolute
# difference — confirm against onnx_diagnostic documentation.
diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.608178615570068, max=4.440975189208984
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

# Save the model exported without submodules: every node sits in one flat graph.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. The function to_onnx calls torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.

/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy, but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule, which contains the fx nodes contributing to the submodule and coming from the previous graph.

Now the ONNX graph.

# Same export, but asking to_onnx to keep submodules as ONNX local functions
# (the printed model declares the 'aten_local_function' domain).
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

# Run the model exported with local functions and compare it with the eager
# PyTorch output ``y`` again.
sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
# run returns a list with one entry per graph output; this model has one.
got = sess.run(None, feeds)[0]

# max_diff returns a dictionary; 'abs' is presumably the maximum absolute
# difference — confirm against onnx_diagnostic documentation.
diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.608178615570068, max=4.440975189208984
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

# Save the model exported with export_modules_as_functions=True.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.

plot exporter recipes c modules

Inlining

The ONNX graph can still be inlined after this.

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and inefficient. That's why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let's see how it goes.

# Enable the default rewriting patterns plus the onnxruntime-specific ones
# (which may introduce onnxruntime kernels), fold constants, and print the
# optimizer's progress (verbose=2 produces the log shown below).
optimization_options = OptimizationOptions(
    patterns="default+onnxruntime",
    constant_folding=True,
    verbose=2,
)
onx_optimized = to_onnx(llm, (input_ids,), options=optimization_options)

# Display the optimized graph in a compact, human-readable form.
print(pretty_onnx(onx_optimized))
[GraphBuilder-MPC.optimize] start with 73 nodes
[GraphBuilder-MPC.optimize] #patterns=110
[GraphBuilder-MPC.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-MPC.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-MPC.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-MPC.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-MPC.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-MPC.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-MPC.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-MPC.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-MPC.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-MPC.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-MPC.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-MPC.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern   9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  46/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  47/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  48/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  49/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  50/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  51/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  52/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  53/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  54/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  55/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  56/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  57/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  58/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  59/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  60/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  61/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  62/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  63/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  64/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  65/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  66/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  67/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  68/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  69/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  70/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  71/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  72/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  73/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  74/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  75/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  76/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  77/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  78/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  79/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  80/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  81/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  82/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  83/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  84/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  85/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  86/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  87/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  88/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  89/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  90/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  91/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  92/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  93/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  94/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  95/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  96/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  97/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  98/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern  99/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 100/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 101/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 102/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 103/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 104/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 105/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 106/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 107/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 108/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 109/110 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-MPC.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-MPC.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-MPC.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-MPC.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.009 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilder-MPC.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-MPC.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-MPC.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-MPC.optimize] increase priority to 1
[GraphBuilderPatternOptimization-MPC.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-MPC.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.007 | max_time=GeluOrtPattern:0.000
[GraphBuilder-MPC.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-MPC.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-MPC.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-MPC.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-MPC.optimize] increase priority to 2
[GraphBuilderPatternOptimization-MPC.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-MPC.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.007 | max_time=IdentityPattern:0.001
[GraphBuilder-MPC.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-MPC.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-MPC.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-MPC.optimize] increase priority to 3
[GraphBuilderPatternOptimization-MPC.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-MPC.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-MPC.optimize] done after 8 iterations with 29 nodes in 0.062
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002436219983792398
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.001055629996699281
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00028470299730543047
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0006414859999495093
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0002723729958233889
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0015001110004959628
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011994099986623041
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.0631003533490002e-05
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007940730065456592
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.000843846002680948
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007394740023300983
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007384079981420655
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0011613660062721465
    STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.010924564998276765
    STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0034808639975381084
    STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009606732000975171
    STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0055053940013749525
    STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009232807999069337
    STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009346571001515258
    STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.007278225999471033
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00030220799817470834
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023898099971120246
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024219399711000733
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022538299526786432
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005251700022199657
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002399180011707358
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019194499691366218
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004212729945720639
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00026861899823416024
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017949100219993852
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002571740042185411
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003091939943260513
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00024347499493160285
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021676699907402508
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018205300511908717
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023227000565384515
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=9.518300066702068e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000253029000305105
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015309299487853423
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001718380044621881
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023593799778609537
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002704960024857428
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017089699758798815
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032928699874901213
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043007500426028855
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019621199317043647
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021913900127401575
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019528599659679458
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00011403099415474571
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00019527900076354854
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004933470045216382
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.759000098099932e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.633699715370312e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002442819968564436
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0026402769981359597
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0034772930048347916
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.499991963617504e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002626389978104271
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0036402740006451495
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00035562499760999344
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020553099966491573
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0036149469997326378
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.949299833853729e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009024710016092286
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017280600150115788
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015789300960022956
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005358339949452784
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025308700423920527
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001983840011234861
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024312499590450898
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015967600484145805
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020447399583645165
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024331200256710872
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017762599964044057
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004994959999748971
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.1007002184633166e-05
    STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.7934998797718436e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00037027399594080634
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024220799969043583
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003843319973384496
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026758600142784417
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025412999821128324
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031544900048174895
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000556958002562169
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0009793869976419955
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016833399786264636
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002430500026093796
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002759019989753142
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006391040078597143
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005725939990952611
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025077899408643134
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044415500087779947
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003154839978378732
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005192650023673195
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003574580005079042
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002703989994188305
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004020529995614197
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033265500678680837
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026588600303512067
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002063550055027008
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031777500043972395
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00020218500139890239
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001938030000019353
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014733300122315995
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015989899839041755
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001624830038053915
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004558331002044724
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016203400082304142
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016109299394884147
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003721749999385793
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002465739999024663
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000236893003602745
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002213320003647823
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018208700566901825
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025003900009323843
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032406800164608285
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002539649976824876
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047074999747565016
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002660249992914032
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002676740004972089
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011481599722173996
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026388999685877934
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00038005400347174145
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004045329951622989
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026186600007349625
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00045842300096410327
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003965280011470895
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002622110005177092
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002332070034753997
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023478899674955755
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=5.5829001212259755e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002393529997789301
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0025612999925215263
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-MPC.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00034933200004161336s with changed=0 scale=0
[GraphBuilder-MPC.optimize] done with 29 nodes in 0.073
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows a kernel FusedMatMul[com.microsoft] which implements a kernel equivalent to Gemm but working for any tensors, not only 2D ones. How does it work on a model which exports the modules as local functions? The optimizer optimizes every local function independently. We reduce the verbosity…

# Optimization settings for the export below: the default patterns plus the
# onnxruntime-specific ones, with constant folding enabled. No ``verbose``
# option is given here, unlike the CUDA export further down.
module_opt_options = OptimizationOptions(
    patterns="default+onnxruntime",
    constant_folding=True,
)

# Export the model again, this time keeping every submodule as an ONNX
# local function, then print the optimized result.
onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=module_opt_options,
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to work as well on this simple case, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.

Optimizations for CUDA

The optimizer may behave differently when it knows the model runs on CUDA. It may use different kernels and apply different optimizations if needed. This may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and is not able to decide which provider would be more useful for some kernels. This could even be decided at runtime.

# Same optimization patterns and constant folding as before, but the
# optimizer is told the target processor is CUDA. verbose=2 makes it log
# each step of the optimization.
cuda_opt_options = OptimizationOptions(
    patterns="default+onnxruntime",
    constant_folding=True,
    verbose=2,
    processor="CUDA",
)

onx_cuda_optimized = to_onnx(llm, (input_ids,), options=cuda_opt_options)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-FOA.optimize] start with 73 nodes
[GraphBuilder-FOA.optimize] #patterns=110
[GraphBuilder-FOA.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-FOA.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-FOA.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-FOA.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-FOA.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-FOA.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-FOA.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-FOA.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-FOA.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-FOA.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-FOA.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-FOA.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern   9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  46/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  47/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  48/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  49/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  50/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  51/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  52/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  53/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  54/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  55/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  56/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  57/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  58/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  59/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  60/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  61/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  62/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  63/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  64/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  65/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  66/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  67/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  68/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  69/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  70/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  71/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  72/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  73/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  74/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  75/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  76/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  77/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  78/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  79/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  80/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  81/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  82/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  83/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  84/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  85/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  86/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  87/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  88/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  89/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  90/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  91/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  92/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  93/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  94/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  95/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  96/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  97/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  98/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern  99/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 100/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 101/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 102/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 103/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 104/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 105/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 106/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 107/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 108/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 109/110 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-FOA.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-FOA.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-FOA.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-FOA.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.008 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilder-FOA.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-FOA.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-FOA.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-FOA.optimize] increase priority to 1
[GraphBuilderPatternOptimization-FOA.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-FOA.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-FOA.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-FOA.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-FOA.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-FOA.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-FOA.optimize] increase priority to 2
[GraphBuilderPatternOptimization-FOA.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-FOA.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-FOA.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-FOA.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-FOA.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-FOA.optimize] increase priority to 3
[GraphBuilderPatternOptimization-FOA.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-FOA.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-FOA.optimize] done after 8 iterations with 29 nodes in 0.055
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002181369964091573
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0007028329964668956
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0003000720025738701
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0007817269979568664
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00026105700453626923
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.001309169991145609
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011582499791984446
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=9.866991604212672e-06
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008425270061707124
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006346559966914356
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006706179992761463
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006892230012454093
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0009456760017201304
    STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0097180170014326
    STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0032827739996719174
    STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.007376120996923419
    STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.005867957999726059
    STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.010904203998507
    STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0066166539982077666
    STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.005185133002669318
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00030789999800617807
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00022889200045028701
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016377499923692085
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003267720021540299
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004484860000957269
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002589309951872565
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017885099805425853
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003952940023737028
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00024390999897150323
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015442700168932788
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023573599173687398
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002687539999897126
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022327400074573234
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001801959988370072
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015055299809318967
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001506930057075806
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=7.587900108774193e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00021298100182320923
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014216300041880459
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016015299843274988
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002688210006454028
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001510740039520897
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016074200175353326
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003547100022842642
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022341599105857313
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018347299555898644
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014886899953125976
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001650890008022543
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.0002868170013243798
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00015636600073776208
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004014039986941498
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.942399992723949e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.277499753399752e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020329499602667056
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002478865993907675
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003275160001066979
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.413004252361134e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015451700164703652
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0029474940020008944
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00033269300547544844
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020108699391130358
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003262512997025624
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.848300083423965e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000729836003301898
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016591500389040448
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014830600048298948
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00036613499469240196
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020516600125120021
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018753900440060534
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002622419997351244
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001616139998077415
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014659300359198824
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020128100004512817
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014963000285206363
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044195000009494834
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.6620000426191837e-05
    STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.711500201257877e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003197980004188139
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029100999745423906
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028975700479350053
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029472999449353665
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020567100000334904
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001450729978387244
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004614160061464645
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0009011590045702178
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001538010037620552
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015636100215488113
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023870299628470093
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044036399413016625
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039383599869324826
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020285500067984685
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000424265002948232
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002908349997596815
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004080829967278987
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030545599656761624
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031663600384490564
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003778260033868719
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025586899937479757
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024193900026148185
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003787440073210746
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026787999377120286
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00020831299480050802
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016877999951248057
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014769299741601571
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014656299754278734
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001638830071897246
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003984964998380747
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014895300409989432
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015272699602064677
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033974399775615893
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002061180020973552
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021407200256362557
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020405499526532367
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017039800150087103
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021828199896845035
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029430000722641125
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023592899742652662
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044643199726124294
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024489700444974005
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023622100343345664
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011218399959034286
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023358099861070514
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035592299900599755
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003567100029613357
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023944900749484077
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023346799571299925
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028098400434828363
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020742199922096916
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020266600404283963
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002037430058408063
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=5.840800440637395e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002167069000279298
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.002404786995612085
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-FOA.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.0003004520003742073s with changed=0 scale=0
[GraphBuilder-FOA.optimize] done with 29 nodes in 0.066
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Comparing the optimized and non-optimized models

The following tool tries to match the nodes and the shape inference results of two models. If the models are not too different, the function is able to find out the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.

# Run the unoptimized model (onx) and the optimized one (onx_optimized) with
# the reference evaluator, then align their intermediate results so that
# operators fused into onnxruntime-only kernels become visible side by side.
res1, res2, align, dc = compare_onnx_execution(
    onx,
    onx_optimized,
    verbose=1,
    cls=ExtendedReferenceEvaluator,
)
print("------------")
# Render the alignment as a text table (one row per matched pair of results).
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA int64    1:1                  BAAA                 in |
004 - | INITIA int64    1:1                  AAAA                 in |
005 - | INITIA int64    1:1                  EAAA                 in |
006 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              AZZA                 p_
007 ~ | INITIA float32  2:16x16              AZZA                 p_ | INITIA float32  2:16x16              CZZA                 p_
008 ~ | INITIA float32  2:16x16              CZZA                 p_ | INITIA float32  2:16x16              CACZ                 p_
009 ~ | INITIA float32  2:16x16              CACZ                 p_ | INITIA float32  2:30x30              KGSP                 sl
010 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
011 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              AZAA                 p_
012 ~ | INITIA float32  2:16x16              AZAA                 p_ | INITIA float32  2:16x16              AAAB                 p_
013 ~ | INITIA float32  2:16x16              AAAB                 p_ | INITIA float32  2:16x16              AABA                 p_
014 ~ | INITIA float32  2:16x16              AABA                 p_ | INITIA float32  2:30x30              KGSP                 sl
015 = | INITIA float32  2:32x16              ZZAA                 p_ | INITIA float32  2:32x16              ZZAA                 p_
016 = | INITIA float32  2:16x128             BCAG                 p_ | INITIA float32  2:16x128             BCAG                 p_
017 = | INITIA float32  2:128x16             AZAB                 p_ | INITIA float32  2:128x16             AZAB                 p_
018 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
019 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
020 = | INITIA float32  2:1024x16            HGMR                 em | INITIA float32  2:1024x16            HGMR                 em
021 = | INITIA float32  2:1024x16            DYWT                 em | INITIA float32  2:1024x16            DYWT                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                ABAA                 de | INITIA float32  1:128                ABAA                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 - | RESULT int64    1:2                  ABAA Concat          Sl |
027 - | RESULT int64    1:2                  EEAA Concat          Sl |
028 - | RESULT int64    1:2                  AAAA Concat          Sl |
029 = | RESULT float32  3:1x30x16            JYWE Gather          em | RESULT float32  3:1x30x16            JYWE Gather          em
030 = | RESULT float32  3:1x30x16            OVZX Gather          em | RESULT float32  3:1x30x16            OVZX Gather          em
031 ~ | RESULT float32  3:1x30x16            XTVB Add             ad | RESULT float32  3:1x30x16            BZYC SkipLayerNormal _o
032 ~ | RESULT float32  3:1x30x16            BZYC LayerNormalizat _o | RESULT float32  3:1x30x1             AAAB SkipLayerNormal un
033 ~ | RESULT float32  3:1x30x16            AFYA MatMul          li | RESULT float32  3:1x30x1             GFFE SkipLayerNormal un
034 ~ | RESULT float32  3:1x30x16            GESE MatMul          li | RESULT float32  3:1x30x16            XTVB SkipLayerNormal ad
035 ~ | RESULT float32  3:1x30x16            HADC MatMul          li | RESULT float32  3:1x30x16            AFYA MatMul          li
036 ~ | RESULT float32  3:1x16x30            CCFX Transpose       tr | RESULT float32  3:1x30x16            GESE MatMul          li
037 ~ | RESULT float32  3:1x30x30            OXVD MatMul          ma | RESULT float32  3:1x30x30            DAFA FusedMatMul     _o
038 ~ | RESULT float32  3:1x30x30            DAFA Mul             _o | RESULT float32  3:1x30x16            HADC MatMul          li
039 - | RESULT float32  2:30x30              KGSP Slice           sl |
040 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
041 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
042 = | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x30            IHHH Softmax         so
043 = | RESULT float32  3:1x30x16            LEED MatMul          ma | RESULT float32  3:1x30x16            LEED MatMul          ma
044 = | RESULT float32  3:1x30x16            AFAY MatMul          li | RESULT float32  3:1x30x16            AFAY MatMul          li
045 = | RESULT float32  3:1x30x16            AZVR MatMul          li | RESULT float32  3:1x30x16            AZVR MatMul          li
046 ~ | RESULT float32  3:1x30x16            WGBU MatMul          li | RESULT float32  3:1x30x30            YWAA FusedMatMul     _o
047 ~ | RESULT float32  3:1x16x30            YPVD Transpose       tr | RESULT float32  3:1x30x16            WGBU MatMul          li
048 ~ | RESULT float32  3:1x30x30            PHXB MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
049 ~ | RESULT float32  3:1x30x30            YWAA Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
050 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IGHH Softmax         so
051 - | RESULT bool     2:30x30              HLZC Equal           eq |
052 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            CZAA MatMul          ma
053 ~ | RESULT float32  3:1x30x30            IGHH Softmax         so | RESULT float32  3:1x30x32            ODED Concat          ca
054 ~ | RESULT float32  3:1x30x16            CZAA MatMul          ma | RESULT float32  3:1x30x16            CAAA MatMul          _o
055 ~ | RESULT float32  3:1x30x32            ODED Concat          ca | RESULT float32  3:1x30x16            HFDE Add             li
056 ~ | RESULT float32  3:1x30x16            CAAA MatMul          _o | RESULT float32  3:1x30x16            AAYC SkipLayerNormal _o
057 ~ | RESULT float32  3:1x30x16            HFDE Add             li | RESULT float32  3:1x30x1             AAAB SkipLayerNormal un
058 ~ | RESULT float32  3:1x30x16            DYZG Add             ad | RESULT float32  3:1x30x1             GFFE SkipLayerNormal un
059 ~ | RESULT float32  3:1x30x16            AAYC LayerNormalizat _o | RESULT float32  3:1x30x16            DYZG SkipLayerNormal ad
060 = | RESULT float32  3:1x30x128           PQAP MatMul          _o | RESULT float32  3:1x30x128           PQAP MatMul          _o
061 = | RESULT float32  3:1x30x128           BBMB Add             li | RESULT float32  3:1x30x128           BBMB Add             li
062 = | RESULT float32  3:1x30x128           VIUL Relu            re | RESULT float32  3:1x30x128           VIUL Relu            re
063 = | RESULT float32  3:1x30x16            YXAA MatMul          _o | RESULT float32  3:1x30x16            YXAA MatMul          _o
064 = | RESULT float32  3:1x30x16            ZXAA Add             li | RESULT float32  3:1x30x16            ZXAA Add             li
065 = | RESULT float32  3:1x30x16            BVAH Add             ou | RESULT float32  3:1x30x16            BVAH Add             ou
066 = | OUTPUT float32  3:1x30x16            BVAH                 ou | OUTPUT float32  3:1x30x16            BVAH                 ou

The conversion should also handle dynamic shapes, since the input sequence can be of any length. That is a topic for another example.

Total running time of the script: (0 minutes 2.205 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

Check the exporter on a dummy from HuggingFace

to_onnx and a custom operator registered with a function

to_onnx and a custom operator registered with a function

Gallery generated by Sphinx-Gallery