to_onnx and submodules from LLMs

Big models are hard to read once converted into onnx. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    """Sum of two learned lookup tables, ``embedding`` and ``pe``.

    Both tables have ``vocab_size`` rows of width ``embedding_dim`` and both
    are indexed with the same input ids.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        table_shape = (vocab_size, embedding_dim)
        self.embedding = torch.nn.Embedding(*table_shape)
        self.pe = torch.nn.Embedding(*table_shape)

    def forward(self, x):
        """Return ``embedding(x) + pe(x)``."""
        return self.embedding(x) + self.pe(x)


class AttentionBlock(torch.nn.Module):
    """One masked self-attention head with bias-free query/key/value projections.

    A lower-triangular ``context_size x context_size`` matrix of ones is stored
    as the buffer ``mask``; positions where it is zero are filled with ``-inf``
    before the softmax.
    """

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()

        def projection():
            return torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        self.query = projection()
        self.key = projection()
        self.value = projection()

        lower_triangle = torch.ones(context_size, context_size, dtype=torch.float).tril()
        self.register_buffer("mask", lower_triangle)

    def forward(self, x):
        """Apply the head to a 3-dimensional input and return a tensor of the same shape."""
        _, seq_len, width = x.size()

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scores are scaled by width ** -0.5, then masked with the top-left
        # seq_len x seq_len corner of the buffer.
        scores = torch.matmul(q, k.transpose(-2, -1)) * width**-0.5
        hidden = self.mask[:seq_len, :seq_len] == 0
        weights = torch.softmax(scores.masked_fill(hidden, float("-inf")), dim=-1)

        return torch.matmul(weights, v)


class MultiAttentionBlock(torch.nn.Module):
    """Runs ``num_heads`` attention heads on the same input and mixes their outputs.

    Head outputs are concatenated on the last axis and projected from
    ``embedding_dim * num_heads`` back to ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        heads = [AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        self.attention = torch.nn.ModuleList(heads)
        self.linear = torch.nn.Linear(embedding_dim * num_heads, embedding_dim)

    def forward(self, x):
        """Return the linear projection of the concatenated head outputs."""
        per_head = [head(x) for head in self.attention]
        return self.linear(torch.cat(per_head, dim=-1))


class FeedForward(torch.nn.Module):
    """Two linear layers with a ReLU in between: embedding_dim -> ff_dim -> embedding_dim."""

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features=embedding_dim, out_features=ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(in_features=ff_dim, out_features=embedding_dim)

    def forward(self, x):
        """Return ``linear_2(relu(linear_1(x)))``."""
        hidden = self.relu(self.linear_1(x))
        return self.linear_2(hidden)


class DecoderLayer(torch.nn.Module):
    """Attention then feed-forward, each fed a layer-normalized input and added back.

    ``norm_1`` is applied before the attention block and ``norm_2`` before the
    feed-forward block; each block's output is summed with that block's
    un-normalized input.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return the output of both residual sub-blocks applied to ``x``."""
        attended = self.attention(self.norm_1(x)) + x
        return self.feed_forward(self.norm_2(attended)) + attended


class LLM(torch.nn.Module):
    """Small model made of an ``Embedding`` followed by a single ``DecoderLayer``.

    The output keeps the embedding width; no projection back to the vocabulary
    is applied.
    """

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            context_size=context_size,
            ff_dim=ff_dim,
        )

    def forward(self, input_ids):
        """Embed ``input_ids`` and return the decoder layer's output."""
        return self.decoder(self.embedding(input_ids))


# Build the model with its default sizes and run it once on random token ids.
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim, dtype=torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.227273941040039, max=3.6770823001861572

First conversion to ONNX

The conversion relies on torch.export.export(), which gives:

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias, 1e-05, False), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias, 1e-05, False), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

Then the function to_onnx converts it into ONNX.

# Export the model with the example input, then print a text rendering of the result.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
  Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
    Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
      Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
        Softmax(masked_fill, axis=-1) -> softmax
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
        MatMul(softmax, linear_2) -> matmul_1
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
      Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there is no discrepancy.

# Run the exported model with onnxruntime on the same ids and compare with torch's output.
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.227273941040039, max=3.6770823001861572
max discrepancy=2.384185791015625e-07

Let’s save the ONNX model.

# Save the model produced by the first call to to_onnx.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. Function to_onnx calls torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.

/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule and coming from the previous graph.

Now the ONNX graph.

# Second export of the same model and input, this time with export_modules_as_functions=True.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

# Run the model exported with export_modules_as_functions=True and compare with torch's output.
sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.227273941040039, max=3.6770823001861572
max discrepancy=2.384185791015625e-07

Let’s save the ONNX model.

# Save the model exported with export_modules_as_functions=True.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.

plot exporter recipes c modules

Inlining

The ONNX graph can still be inlined after this.

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are made to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

# Export the model again, this time requesting the onnxruntime-specific
# patterns in addition to the default ones ("default+onnxruntime").
# NOTE(review): verbose=2 presumably produces the optimizer log printed
# below (patterns tried, matches applied, per-pattern STAT timings) — confirm.
onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True, verbose=2),
)
# Print a text rendering of the optimized graph.
print(pretty_onnx(onx_optimized))
[GraphBuilder-OFO.optimize] start with 73 nodes
[GraphBuilder-OFO.optimize] #patterns=111
[GraphBuilder-OFO.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-OFO.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-OFO.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-OFO.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-OFO.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-OFO.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-OFO.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-OFO.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-OFO.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-OFO.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-OFO.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-OFO.optimize] start with 53 nodes, 28 initializers, 111 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   1/111 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   2/111 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   3/111 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   4/111 - P0 - CastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   5/111 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   6/111 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   7/111 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   8/111 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern   9/111 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  10/111 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  11/111 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  12/111 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  13/111 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  14/111 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  15/111 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  16/111 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  17/111 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  18/111 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  19/111 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  20/111 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  21/111 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  22/111 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  23/111 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  24/111 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  25/111 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  26/111 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  27/111 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  28/111 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  29/111 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  30/111 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  31/111 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  32/111 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  33/111 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  34/111 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  35/111 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  36/111 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  37/111 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  38/111 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  39/111 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  40/111 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  41/111 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  42/111 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  43/111 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  44/111 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  45/111 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  46/111 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  47/111 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  48/111 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  49/111 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  50/111 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  51/111 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  52/111 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  53/111 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  54/111 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  55/111 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  56/111 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  57/111 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  58/111 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  59/111 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  60/111 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  61/111 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  62/111 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  63/111 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  64/111 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  65/111 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  66/111 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  67/111 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  68/111 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  69/111 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  70/111 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  71/111 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  72/111 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  73/111 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  74/111 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  75/111 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  76/111 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  77/111 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  78/111 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  79/111 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  80/111 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  81/111 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  82/111 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  83/111 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  84/111 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  85/111 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  86/111 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  87/111 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  88/111 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  89/111 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  90/111 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  91/111 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  92/111 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  93/111 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  94/111 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  95/111 - P1 - SwapRangeAddScalarPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  96/111 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  97/111 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  98/111 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern  99/111 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 100/111 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 101/111 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 102/111 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 103/111 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 104/111 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 105/111 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 106/111 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 107/111 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 108/111 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 109/111 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 110/111 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-OFO.optimize] use pattern 111/111 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-OFO.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-OFO.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-OFO.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.010 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilder-OFO.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-OFO.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-OFO.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-OFO.optimize] increase priority to 1
[GraphBuilderPatternOptimization-OFO.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-OFO.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.007 | max_time=IdentityPattern:0.000
[GraphBuilder-OFO.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-OFO.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-OFO.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-OFO.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-OFO.optimize] increase priority to 2
[GraphBuilderPatternOptimization-OFO.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-OFO.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-OFO.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-OFO.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-OFO.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-OFO.optimize] increase priority to 3
[GraphBuilderPatternOptimization-OFO.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-OFO.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-OFO.optimize] done after 8 iterations with 29 nodes in 0.068
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002796830000306727
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0008231299999579278
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00040775000002213346
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0008430249999946682
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0003837829999611131
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014849640001557418
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00015791499993156322
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.2842999922213494e-05
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008780440000464296
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007604280000350627
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007499829998778296
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007617989999744168
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0011700530001235165
    STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.012680772000067009
    STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0043271060000051875
    STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009838464999916141
    STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.008060679999971399
    STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.012667783999972926
    STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.007314950000022691
    STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.006530704999931913
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0004188979997934439
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.000340633000064372
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002591709998114311
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003446879999273733
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005575349999844548
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003322699999444012
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002226780000000872
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000522155999988172
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00030480999998871994
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019006400020771252
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00027801299995644513
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00032002899979488575
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002825740000389487
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023173900001438597
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017946199977814103
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024100400014503975
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00012077899998530484
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00027817299985599675
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001635399999031506
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018079200015108654
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00026579300015328045
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020278800002415664
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002538889998504601
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00038194100000055187
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031582799988427723
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022269099997629382
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002135419999831356
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002463259999103684
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.0001035910000837248
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0002082470001596448
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0005211210000197752
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010804199996528041
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.274899994376028e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002508210000087274
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0031837300001598123
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00400782200017602
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=1.4601000088987348e-05
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018471200007752486
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.003320490999840331
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00035135500013439014
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022237500002120214
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003458630000068297
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.07420000228376e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008081290000063746
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026140299985399906
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021998199997597112
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043513500008884876
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002459989999579193
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029559799997969094
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003301319998172403
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002050850000614446
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024473600012697716
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002501779999874998
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018214800002169795
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006007420001878927
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.7806999923996045e-05
    STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.8638000003411435e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039604199992027134
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027059399997142464
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039018499978737964
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002764679998108477
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002886350001745086
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020024400009788224
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0006482430000005479
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0012550730000384647
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019046899990371458
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019230600014452648
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002702659999158641
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005534680001346715
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004929219999212364
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002562129998295859
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00048176499979035725
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003040590003138277
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000511414000015975
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003237330001866212
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000301928000112639
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00036799400015752326
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002717099998790218
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002654090000078213
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003216129999827899
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039327800016053516
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003443280002102256
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002368949999436154
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021866199983833212
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023388499994325684
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000198361999878216
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005661125999949945
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020763200006967963
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002152060001208156
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005040850002160369
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002940769999213444
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003017619998217924
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029393400006938464
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023319600006743713
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028877600004761916
    STAT match_SwapRangeAddScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020586899995578278
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000501877000147033
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004460769998786418
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006232089999684831
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033637699982591585
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003299620001371295
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00013503599996056437
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00036288399985551223
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000487903000021106
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000557096000193269
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000331902000084483
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031846400008817
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004097160001492739
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000298998999937794
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031692700019902986
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028527600022698607
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=6.550699993113085e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002454932999967241
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0023611650001384987
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-OFO.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00024625599996852543s with changed=0 scale=0
[GraphBuilder-OFO.optimize] done with 29 nodes in 0.081
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows a kernel FusedMatMul[com.microsoft] which implements a kernel equivalent to Gemm but working for any tensors, not only 2D. How does it work on the model which keeps the modules exported as local functions? The optimizer optimizes every local function independently. We reduce the verbosity…

onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to be working as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.

Optimizations for CUDA

The optimizer may behave differently knowing the model is running on CUDA. It may use different kernels and do different optimizations if needed. That may not be the right place to do it as the runtime may choose to run one kernel on CPU, another one on CUDA. The current optimization does not know that and is not able to decide which provider would be more useful for some kernels. This could even be decided at runtime.

onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-BIM.optimize] start with 73 nodes
[GraphBuilder-BIM.optimize] #patterns=111
[GraphBuilder-BIM.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-BIM.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-BIM.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-BIM.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-BIM.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-BIM.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-BIM.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-BIM.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-BIM.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-BIM.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-BIM.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-BIM.optimize] start with 53 nodes, 28 initializers, 111 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   1/111 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   2/111 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   3/111 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   4/111 - P0 - CastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   5/111 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   6/111 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   7/111 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   8/111 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern   9/111 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  10/111 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  11/111 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  12/111 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  13/111 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  14/111 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  15/111 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  16/111 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  17/111 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  18/111 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  19/111 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  20/111 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  21/111 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  22/111 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  23/111 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  24/111 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  25/111 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  26/111 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  27/111 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  28/111 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  29/111 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  30/111 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  31/111 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  32/111 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  33/111 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  34/111 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  35/111 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  36/111 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  37/111 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  38/111 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  39/111 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  40/111 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  41/111 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  42/111 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  43/111 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  44/111 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  45/111 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  46/111 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  47/111 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  48/111 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  49/111 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  50/111 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  51/111 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  52/111 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  53/111 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  54/111 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  55/111 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  56/111 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  57/111 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  58/111 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  59/111 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  60/111 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  61/111 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  62/111 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  63/111 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  64/111 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  65/111 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  66/111 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  67/111 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  68/111 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  69/111 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  70/111 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  71/111 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  72/111 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  73/111 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  74/111 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  75/111 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  76/111 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  77/111 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  78/111 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  79/111 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  80/111 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  81/111 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  82/111 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  83/111 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  84/111 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  85/111 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  86/111 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  87/111 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  88/111 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  89/111 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  90/111 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  91/111 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  92/111 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  93/111 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  94/111 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  95/111 - P1 - SwapRangeAddScalarPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  96/111 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  97/111 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  98/111 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern  99/111 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 100/111 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 101/111 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 102/111 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 103/111 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 104/111 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 105/111 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 106/111 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 107/111 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 108/111 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 109/111 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 110/111 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-BIM.optimize] use pattern 111/111 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-BIM.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-BIM.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-BIM.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.014 | max_time=GeluOrtPattern:0.003
[GraphBuilder-BIM.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-BIM.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-BIM.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-BIM.optimize] increase priority to 1
[GraphBuilderPatternOptimization-BIM.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-BIM.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.011 | max_time=IdentityPattern:0.001
[GraphBuilder-BIM.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-BIM.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-BIM.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.008 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization-BIM.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-BIM.optimize] increase priority to 2
[GraphBuilderPatternOptimization-BIM.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-BIM.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilder-BIM.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-BIM.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-BIM.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-BIM.optimize] increase priority to 3
[GraphBuilderPatternOptimization-BIM.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-BIM.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-BIM.optimize] done after 8 iterations with 29 nodes in 0.083
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0003927450001128818
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0008372890000600819
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0005206710000038584
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0008976920000804967
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0005576270000346994
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0024291549998451956
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00019574699990698718
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.4463999832514673e-05
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0010345870000492141
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008410969999204099
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0010721369999373564
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0012714620003180244
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0013911949999965145
    STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.017398427999978594
    STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.005836232000092423
    STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.014595395000014832
    STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.010892864999959784
    STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.013563587000021471
    STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00798747600003935
    STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.006097939000028418
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005423460000884006
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0003526250000049913
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021156300010716222
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00048122600003353
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0010116020000623394
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000383157000214851
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00030494800000724354
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0006804939999938142
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00040012800013755623
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000268471999902431
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005747600000631792
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0004547399997818502
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00034406099996431294
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00027270700002191006
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002884139998968749
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023516000010204152
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.0001008420001653576
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00035215200000493496
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021424800013392087
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00024107000012918434
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00034454600006483815
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023592499996993865
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019912000004751462
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004915349996963414
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002977619999455783
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028267300012885244
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021592700011296984
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021003400013341889
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.476300001802883e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00019111500000690285
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.000505160999864529
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.24880000259509e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.090999997672043e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003279179999253756
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004492137999932311
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005713610999919183
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=2.005099997859361e-05
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001923699999224482
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0046795250000286615
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0004254409998338815
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022064900008444965
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004560923000212824
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.416199989369488e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009771130002036443
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021876400001019647
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023216500028411247
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047343099981844716
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028930899998158566
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002252620001854666
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009103800001639684
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021284299987200939
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019017100021301303
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030792200016094284
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021390700021584053
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007301869998173061
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.6609999963511655e-05
    STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.7394999960961286e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004362900000387526
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029923699992195907
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000371487000165871
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028750000012678356
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005287090002639161
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017884500005038717
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005873640001254898
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0011766749997832449
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019239600021592196
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019355699998868658
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029006000011122524
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005453239999724246
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005247369999779039
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002285820000906824
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005957569999281986
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032299300005433906
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006123439998191316
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032893900026920164
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028429400015284045
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00037383399978807574
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003264350001472849
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002826959997719314
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002345729999433388
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004291999998713436
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002579660001629236
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024846100006925553
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023452700008874672
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018373000000337925
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001962640000101601
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00596758300014244
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019786300003943325
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024092799992558867
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007429560000673519
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00037768300001062016
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00034901799995168403
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003510870001264266
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002302809999719102
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003316210001003128
    STAT match_SwapRangeAddScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019076599994605203
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000513451999836434
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004026749999184176
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006422130001055848
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035893599999781145
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003505620001078569
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00023595400000431255
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003478300000097079
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000461320000113119
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00045558099998288526
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00041218200021830853
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003610889998526545
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004339389998904153
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00034320500003559573
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033947700012504356
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004482400000824782
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=7.52230002944998e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.003287892000003012
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0033581790000880574
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-BIM.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00022337699999752658s with changed=0 scale=0
[GraphBuilder-BIM.optimize] done with 29 nodes in 0.095
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Comparing the optimized and non-optimized models

The following tool tries to match the nodes and the shape inference results of two models. If the models are not too different, the function is able to find the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.

# Execute both models and align their results. With verbose=1 the function
# logs each step (input generation, execution of each model, edit distance).
# ExtendedReferenceEvaluator is passed as the evaluator class (``cls``).
res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
# Render the alignment as text: one row per pair, first model on the left,
# second model on the right.
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA int64    1:1                  BAAA                 in |
004 - | INITIA int64    1:1                  AAAA                 in |
005 - | INITIA int64    1:1                  EAAA                 in |
006 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              ZAAA                 p_
007 ~ | INITIA float32  2:16x16              ZAAA                 p_ | INITIA float32  2:16x16              AABA                 p_
008 ~ | INITIA float32  2:16x16              AABA                 p_ | INITIA float32  2:16x16              AAAZ                 p_
009 ~ | INITIA float32  2:16x16              AAAZ                 p_ | INITIA float32  2:30x30              KGSP                 sl
010 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
011 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              AAAZ                 p_
012 ~ | INITIA float32  2:16x16              AAAZ                 p_ | INITIA float32  2:16x16              AYAA                 p_
013 ~ | INITIA float32  2:16x16              AYAA                 p_ | INITIA float32  2:16x16              AZAA                 p_
014 ~ | INITIA float32  2:16x16              AZAA                 p_ | INITIA float32  2:30x30              KGSP                 sl
015 = | INITIA float32  2:32x16              BCAA                 p_ | INITIA float32  2:32x16              BCAA                 p_
016 = | INITIA float32  2:16x128             AEBU                 p_ | INITIA float32  2:16x128             AEBU                 p_
017 = | INITIA float32  2:128x16             AZAY                 p_ | INITIA float32  2:128x16             AZAY                 p_
018 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
019 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
020 = | INITIA float32  2:1024x16            XLXT                 em | INITIA float32  2:1024x16            XLXT                 em
021 = | INITIA float32  2:1024x16            KQDM                 em | INITIA float32  2:1024x16            KQDM                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                AAAA                 de | INITIA float32  1:128                AAAA                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 - | RESULT int64    1:2                  ABAA Concat          Sl |
027 - | RESULT int64    1:2                  EEAA Concat          Sl |
028 - | RESULT int64    1:2                  AAAA Concat          Sl |
029 = | RESULT float32  3:1x30x16            FTQA Gather          em | RESULT float32  3:1x30x16            FTQA Gather          em
030 = | RESULT float32  3:1x30x16            JHDX Gather          em | RESULT float32  3:1x30x16            JHDX Gather          em
031 ~ | RESULT float32  3:1x30x16            OATX Add             ad | RESULT float32  3:1x30x16            AAXD SkipLayerNormal _o
032 ~ | RESULT float32  3:1x30x16            AAXD LayerNormalizat _o | RESULT float32  3:1x30x1             BZBA SkipLayerNormal un
033 ~ | RESULT float32  3:1x30x16            WGBZ MatMul          li | RESULT float32  3:1x30x1             GFFE SkipLayerNormal un
034 ~ | RESULT float32  3:1x30x16            GEVQ MatMul          li | RESULT float32  3:1x30x16            OATX SkipLayerNormal ad
035 ~ | RESULT float32  3:1x30x16            WYCX MatMul          li | RESULT float32  3:1x30x16            WGBZ MatMul          li
036 ~ | RESULT float32  3:1x16x30            XSDD Transpose       tr | RESULT float32  3:1x30x16            GEVQ MatMul          li
037 ~ | RESULT float32  3:1x30x30            UAMZ MatMul          ma | RESULT float32  3:1x30x30            ZAXA FusedMatMul     _o
038 ~ | RESULT float32  3:1x30x30            ZAXA Mul             _o | RESULT float32  3:1x30x16            WYCX MatMul          li
039 - | RESULT float32  2:30x30              KGSP Slice           sl |
040 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
041 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
042 = | RESULT float32  3:1x30x30            HGHH Softmax         so | RESULT float32  3:1x30x30            HGHH Softmax         so
043 = | RESULT float32  3:1x30x16            ZXXZ MatMul          ma | RESULT float32  3:1x30x16            ZXXZ MatMul          ma
044 = | RESULT float32  3:1x30x16            YFZA MatMul          li | RESULT float32  3:1x30x16            YFZA MatMul          li
045 = | RESULT float32  3:1x30x16            AACY MatMul          li | RESULT float32  3:1x30x16            AACY MatMul          li
046 ~ | RESULT float32  3:1x30x16            XFDG MatMul          li | RESULT float32  3:1x30x30            YABZ FusedMatMul     _o
047 ~ | RESULT float32  3:1x16x30            XGCV Transpose       tr | RESULT float32  3:1x30x16            XFDG MatMul          li
048 ~ | RESULT float32  3:1x30x30            PDHV MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
049 ~ | RESULT float32  3:1x30x30            YABZ Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
050 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IHHH Softmax         so
051 - | RESULT bool     2:30x30              HLZC Equal           eq |
052 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            TZEE MatMul          ma
053 ~ | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x32            RWBC Concat          ca
054 ~ | RESULT float32  3:1x30x16            TZEE MatMul          ma | RESULT float32  3:1x30x16            DZAA MatMul          _o
055 ~ | RESULT float32  3:1x30x32            RWBC Concat          ca | RESULT float32  3:1x30x16            AWWX Add             li
056 ~ | RESULT float32  3:1x30x16            DZAA MatMul          _o | RESULT float32  3:1x30x16            ZBXD SkipLayerNormal _o
057 ~ | RESULT float32  3:1x30x16            AWWX Add             li | RESULT float32  3:1x30x1             BYBA SkipLayerNormal un
058 ~ | RESULT float32  3:1x30x16            OVPU Add             ad | RESULT float32  3:1x30x1             GFFE SkipLayerNormal un
059 ~ | RESULT float32  3:1x30x16            ZBXD LayerNormalizat _o | RESULT float32  3:1x30x16            OVPU SkipLayerNormal ad
060 = | RESULT float32  3:1x30x128           LIFE MatMul          _o | RESULT float32  3:1x30x128           LIFE MatMul          _o
061 = | RESULT float32  3:1x30x128           GFBA Add             li | RESULT float32  3:1x30x128           GFBA Add             li
062 = | RESULT float32  3:1x30x128           WBAP Relu            re | RESULT float32  3:1x30x128           WBAP Relu            re
063 = | RESULT float32  3:1x30x16            MTTU MatMul          _o | RESULT float32  3:1x30x16            MTTU MatMul          _o
064 = | RESULT float32  3:1x30x16            NUUU Add             li | RESULT float32  3:1x30x16            NUUU Add             li
065 = | RESULT float32  3:1x30x16            APIN Add             ou | RESULT float32  3:1x30x16            APIN Add             ou
066 = | OUTPUT float32  3:1x30x16            APIN                 ou | OUTPUT float32  3:1x30x16            APIN                 ou

The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that’s a topic for another example.

Total running time of the script: (0 minutes 2.687 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

Check the exporter on a dummy from HuggingFace

to_onnx and a custom operator registered with a function

to_onnx and a custom operator registered with a function

Gallery generated by Sphinx-Gallery