to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    """Sum of two learned lookup tables applied to the same ids.

    ``embedding`` and ``pe`` are both ``torch.nn.Embedding`` tables of size
    ``(vocab_size, embedding_dim)``. Note that ``pe`` is looked up with the
    very same ids as ``embedding`` (not with positions).
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        # Creation order matters: parameter initialisation consumes the RNG.
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )
        self.pe = torch.nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )

    def forward(self, x):
        """Return ``embedding(x) + pe(x)`` for a tensor of integer ids ``x``."""
        return self.embedding(x) + self.pe(x)


class AttentionBlock(torch.nn.Module):
    """Single attention head with a lower-triangular (causal) mask.

    Three bias-free linear projections (query, key, value) map the input to
    ``embedding_dim`` features. A ``(context_size, context_size)``
    lower-triangular buffer named ``mask`` hides future positions.
    """

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        # Keep the creation order query -> key -> value: weight
        # initialisation draws from the global RNG.
        self.query = torch.nn.Linear(
            in_features=embedding_dim, out_features=embedding_dim, bias=False
        )
        self.key = torch.nn.Linear(
            in_features=embedding_dim, out_features=embedding_dim, bias=False
        )
        self.value = torch.nn.Linear(
            in_features=embedding_dim, out_features=embedding_dim, bias=False
        )

        lower_triangle = torch.tril(
            torch.ones(context_size, context_size, dtype=torch.float)
        )
        self.register_buffer("mask", lower_triangle)

    def forward(self, x):
        """Apply masked attention to ``x``, unpacked as ``(batch, T, C)``."""
        _, seq_len, channels = x.size()

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scaled scores; entries where the mask is 0 are set to -inf so that
        # the softmax gives them a zero weight.
        scores = q @ k.transpose(-2, -1) * channels**-0.5
        hidden = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float("-inf"))
        weights = torch.nn.functional.softmax(scores, dim=-1)

        return weights @ v


class MultiAttentionBlock(torch.nn.Module):
    """Several :class:`AttentionBlock` heads followed by a linear projection.

    Every head receives the same input; their outputs are concatenated along
    the last axis and projected from ``embedding_dim * num_heads`` back to
    ``embedding_dim`` features.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        heads = [AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        self.attention = torch.nn.ModuleList(heads)
        self.linear = torch.nn.Linear(embedding_dim * num_heads, embedding_dim)

    def forward(self, x):
        """Run every head on ``x``, concatenate, then project."""
        per_head = [head(x) for head in self.attention]
        merged = torch.cat(per_head, dim=-1)
        return self.linear(merged)


class FeedForward(torch.nn.Module):
    """Two linear layers with a ReLU in between.

    ``embedding_dim -> ff_dim -> embedding_dim``.
    """

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features=embedding_dim, out_features=ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(in_features=ff_dim, out_features=embedding_dim)

    def forward(self, x):
        """Return ``linear_2(relu(linear_1(x)))``."""
        hidden = self.relu(self.linear_1(x))
        return self.linear_2(hidden)


class DecoderLayer(torch.nn.Module):
    """Attention and feed-forward sub-layers, each with a residual connection.

    Layer normalisation is applied to the *input* of each sub-layer and the
    un-normalised input is added back to the sub-layer output.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        # Creation order is kept as is (parameter initialisation uses the RNG).
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return ``h + feed_forward(norm_2(h))`` with ``h = x + attention(norm_1(x))``."""
        # First residual branch: normalise, attend, add the input back.
        after_attention = self.attention(self.norm_1(x)) + x

        # Second residual branch: normalise, feed-forward, add back.
        return self.feed_forward(self.norm_2(after_attention)) + after_attention


class LLM(torch.nn.Module):
    """Toy model: an :class:`Embedding` followed by one :class:`DecoderLayer`.

    The output has ``embedding_dim`` features per input id; there is no
    projection back to the vocabulary.
    """

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        """Embed ``input_ids`` and pass the result through the decoder layer."""
        embedded = self.embedding(input_ids)
        return self.decoder(embedded)


# Instantiate the model with its default hyper-parameters, then run it once
# on a random batch of ids of shape (1, 30) drawn from [0, 1024).
# The model is created before the input so the RNG is consumed in this order.
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)

# Quick summary of the torch output, used as the reference below.
print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-3.8575398921966553, max=3.8025996685028076

First conversion to ONNX

The conversion relies on torch.export.export(), which gives:

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias, 1e-05, False), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias, 1e-05, False), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

Then function to_onnx converts it into ONNX.

# Export the model with no extra option and print a text rendering of the result.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
  Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
    Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
      Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
        Softmax(masked_fill, axis=-1) -> softmax
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
        MatMul(softmax, linear_2) -> matmul_1
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
      Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there is no discrepancy.

# Run the exported model with onnxruntime (CPU provider) on the same ids
# that were given to the torch model, and keep the first output.
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

# Compare against the torch output `y`; the 'abs' entry is presumably the
# maximum absolute difference.
diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.857539653778076, max=3.8025996685028076
max discrepancy=2.384185791015625e-07

Let’s save the ONNX model.

# Write the first (flat) exported model to disk.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. Function to_onnx is calling the function torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.

/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule and coming from the previous graph.

Now the ONNX graph.

# Same export as before, this time with export_modules_as_functions=True,
# then print a text rendering of the result.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

# Run the model exported with export_modules_as_functions=True through
# onnxruntime (CPU provider) on the same ids, and keep the first output.
sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

# Compare against the torch output `y`; the 'abs' entry is presumably the
# maximum absolute difference.
diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.857539653778076, max=3.8025996685028076
max discrepancy=2.384185791015625e-07

Let’s save the ONNX model.

# Write the model exported with export_modules_as_functions=True to disk.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.

plot exporter recipes c modules

Inlining

The ONNX graph can still be inlined after this.

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are made to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

# Export the model again, this time requesting the "default+onnxruntime"
# pattern set together with constant folding. verbose=2 is presumably what
# makes the optimizer print the progress log shown below.
optimization_options = OptimizationOptions(
    patterns="default+onnxruntime",
    constant_folding=True,
    verbose=2,
)
onx_optimized = to_onnx(llm, (input_ids,), options=optimization_options)
print(pretty_onnx(onx_optimized))
[GraphBuilder-EAA.optimize] start with 73 nodes
[GraphBuilder-EAA.optimize] #patterns=121
[GraphBuilder-EAA.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-EAA.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-EAA.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-EAA.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-EAA.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-EAA.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-EAA.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-EAA.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-EAA.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-EAA.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-EAA.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-EAA.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-EAA.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-EAA.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-EAA.optimize] start with 53 nodes, 28 initializers, 121 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   1/121 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   2/121 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   3/121 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   4/121 - P0 - CastPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   5/121 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   6/121 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   7/121 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   8/121 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern   9/121 - P0 - FunctionAttentionGQAPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  10/121 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  11/121 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  12/121 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  13/121 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  14/121 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  15/121 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  16/121 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  17/121 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  18/121 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  19/121 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  20/121 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  21/121 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  22/121 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  23/121 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  24/121 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  25/121 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  26/121 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  27/121 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  28/121 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  29/121 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  30/121 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  31/121 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  32/121 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  33/121 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  34/121 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  35/121 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  36/121 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  37/121 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  38/121 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  39/121 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  40/121 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  41/121 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  42/121 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  43/121 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  44/121 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  45/121 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  46/121 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  47/121 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  48/121 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  49/121 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  50/121 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  51/121 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  52/121 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  53/121 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  54/121 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  55/121 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  56/121 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  57/121 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  58/121 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  59/121 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  60/121 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  61/121 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  62/121 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  63/121 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  64/121 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  65/121 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  66/121 - P1 - MissingReduceMaxPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  67/121 - P1 - MissingTopKPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  68/121 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  69/121 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  70/121 - P1 - NotNotPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  71/121 - P1 - NotWherePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  72/121 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  73/121 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  74/121 - P1 - RMSNormalizationMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  75/121 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  76/121 - P1 - ReduceArgTopKPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  77/121 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  78/121 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  79/121 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  80/121 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  81/121 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  82/121 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  83/121 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  84/121 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  85/121 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  86/121 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  87/121 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  88/121 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  89/121 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  90/121 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  91/121 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  92/121 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  93/121 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  94/121 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  95/121 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  96/121 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  97/121 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  98/121 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern  99/121 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 100/121 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 101/121 - P1 - SwapRangeAddScalarPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 102/121 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 103/121 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 104/121 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 105/121 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 106/121 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 107/121 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 108/121 - P1 - WhereAddPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 109/121 - P2 - AttentionGQAPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 110/121 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 111/121 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 112/121 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 113/121 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 114/121 - P2 - GroupQueryAttention3DPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 115/121 - P2 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 116/121 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 117/121 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 118/121 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 119/121 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 120/121 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-EAA.optimize] use pattern 121/121 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-EAA.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-EAA.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-EAA.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.010 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilder-EAA.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-EAA.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-EAA.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-EAA.optimize] increase priority to 1
[GraphBuilderPatternOptimization-EAA.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-EAA.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.007 | max_time=IdentityPattern:0.000
[GraphBuilder-EAA.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-EAA.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-EAA.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-EAA.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-EAA.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-EAA.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-EAA.optimize] increase priority to 2
[GraphBuilderPatternOptimization-EAA.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-EAA.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilder-EAA.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-EAA.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-EAA.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-EAA.optimize] increase priority to 3
[GraphBuilderPatternOptimization-EAA.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-EAA.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-EAA.optimize] done after 8 iterations with 29 nodes in 0.069
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002860429958673194
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0007790100062265992
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0004455059897736646
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0007283020022441633
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00033534901012899354
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0017056899887393229
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00015481599984923378
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.1683005141094327e-05
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.000881986998138018
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007480690037482418
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007900379932834767
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007875739975133911
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0010953340024570934
    STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.012647089002712164
    STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.004608982002537232
    STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009446561998629477
    STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.008148878005158622
    STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.013961549993837252
    STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00793398599489592
    STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.005819630998303182
    STAT match_AttentionGQAPattern +0 -0 #it=3 maxmatch=0 i=0 - time=7.925799582153559e-05
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0003406409887247719
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002844090122380294
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020976599626010284
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022205498680705205
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005340769930626266
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00031491199479205534
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002316430036444217
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004972940005245619
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0003099699897575192
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019658400560729206
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002669509995030239
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003603480145102367
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00028484001086326316
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002214330088463612
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001876279930002056
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002094449955620803
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00010075999307446182
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00030334298935486004
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001694570019026287
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001868599938461557
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002756060057436116
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018095600535161793
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021817300148541108
    STAT match_FunctionAttentionGQAPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003973829952883534
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0010128219946636818
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032485600240761414
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002314160010428168
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002374159957980737
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022281001292867586
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.622900350019336e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00022438000451074913
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004862389905611053
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.577399810543284e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.985399861354381e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00024274700263049453
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003666188997158315
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005019265998271294
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=7.975002517923713e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020285699429223314
    STAT match_GroupQueryAttention3DPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00010309099889127538
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.002766459008853417
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00037449800583999604
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021956999989924952
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0034703869896475226
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.667000318178907e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008281329937744886
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020591599604813382
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019272400095360354
    STAT match_MissingReduceMaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019710299966391176
    STAT match_MissingTopKPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001945850090123713
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00046741801634198055
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002488940008333884
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00010643000132404268
    STAT match_NotNotPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018978298612637445
    STAT match_NotWherePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001819200042518787
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033227600215468556
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020351699640741572
    STAT match_RMSNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002063340143649839
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021428400214063004
    STAT match_ReduceArgTopKPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022289999469649047
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024637999740662053
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002065009975922294
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006804819931858219
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.726900129346177e-05
    STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.7182999474462122e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00040938601159723476
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029495898343157023
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00040913900738814846
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027460799901746213
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002946459935628809
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023018800129648298
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005969650010229088
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0010694059965317138
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020498999947449192
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020895700436085463
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028443899645935744
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006300949971773662
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005508530011866242
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023806600802345201
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005076640009065159
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003886499907821417
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005099700065329671
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003352430067025125
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003185010064044036
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00036401101533556357
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003041789968847297
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028991300496272743
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023812198196537793
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003915529960067943
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002451850086799823
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002113519949489273
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018638300389284268
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000189096994290594
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020342000789241865
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004723661004391033
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018962498870678246
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019338700803928077
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004355750061222352
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000280373016721569
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030584100022679195
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027274600142845884
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023766600497765467
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002707110033952631
    STAT match_SwapRangeAddScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022525800159201026
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003967070078942925
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003188519840477966
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005761369975516573
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003185170062351972
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029591500060632825
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011855600314447656
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003086620054091327
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000500725996971596
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004509639911702834
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003127879972453229
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031054199644131586
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00038572499761357903
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000320604995067697
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002773789965431206
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002844379923772067
    STAT match_WhereAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002816110063577071
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=5.9851983678527176e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0023228010104503483
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0025592949968995526
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-EAA.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00021789500169688836s with changed=0 scale=0
[GraphBuilder-EAA.optimize] done with 29 nodes in 0.080
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows a kernel FusedMatMul[com.microsoft], which implements a kernel equivalent to Gemm but working for tensors of any rank, not only 2D. How does it work on the model which exports the modules as local functions? The optimizer optimizes every local function independently. We reduce the verbosity…

# Export the model while keeping its submodules as ONNX local functions
# (export_modules_as_functions=True); no verbosity is requested here.
module_options = OptimizationOptions(
    constant_folding=True,
    patterns="default+onnxruntime",
)
onx_module_optimized = to_onnx(
    llm, (input_ids,), export_modules_as_functions=True, options=module_options
)
print(pretty_onnx(onx_module_optimized))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to work as well on this simple case, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.

Optimizations for CUDA

The optimizer may behave differently when it knows the model is running on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and is not able to decide which provider would be more useful for some kernels. This could even be decided at runtime.

# Tell the optimizer that the target processor is CUDA; verbose=2 makes it
# print its progress while optimizing.
cuda_options = OptimizationOptions(
    processor="CUDA",
    verbose=2,
    constant_folding=True,
    patterns="default+onnxruntime",
)
onx_cuda_optimized = to_onnx(llm, (input_ids,), options=cuda_options)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-KSU.optimize] start with 73 nodes
[GraphBuilder-KSU.optimize] #patterns=121
[GraphBuilder-KSU.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-KSU.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-KSU.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-KSU.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-KSU.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-KSU.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-KSU.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-KSU.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-KSU.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-KSU.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-KSU.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-KSU.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-KSU.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-KSU.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-KSU.optimize] start with 53 nodes, 28 initializers, 121 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   1/121 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   2/121 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   3/121 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   4/121 - P0 - CastPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   5/121 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   6/121 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   7/121 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   8/121 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern   9/121 - P0 - FunctionAttentionGQAPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  10/121 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  11/121 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  12/121 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  13/121 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  14/121 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  15/121 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  16/121 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  17/121 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  18/121 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  19/121 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  20/121 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  21/121 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  22/121 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  23/121 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  24/121 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  25/121 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  26/121 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  27/121 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  28/121 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  29/121 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  30/121 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  31/121 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  32/121 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  33/121 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  34/121 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  35/121 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  36/121 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  37/121 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  38/121 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  39/121 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  40/121 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  41/121 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  42/121 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  43/121 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  44/121 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  45/121 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  46/121 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  47/121 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  48/121 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  49/121 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  50/121 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  51/121 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  52/121 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  53/121 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  54/121 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  55/121 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  56/121 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  57/121 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  58/121 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  59/121 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  60/121 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  61/121 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  62/121 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  63/121 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  64/121 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  65/121 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  66/121 - P1 - MissingReduceMaxPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  67/121 - P1 - MissingTopKPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  68/121 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  69/121 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  70/121 - P1 - NotNotPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  71/121 - P1 - NotWherePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  72/121 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  73/121 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  74/121 - P1 - RMSNormalizationMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  75/121 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  76/121 - P1 - ReduceArgTopKPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  77/121 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  78/121 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  79/121 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  80/121 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  81/121 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  82/121 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  83/121 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  84/121 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  85/121 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  86/121 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  87/121 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  88/121 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  89/121 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  90/121 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  91/121 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  92/121 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  93/121 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  94/121 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  95/121 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  96/121 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  97/121 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  98/121 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern  99/121 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 100/121 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 101/121 - P1 - SwapRangeAddScalarPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 102/121 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 103/121 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 104/121 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 105/121 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 106/121 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 107/121 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 108/121 - P1 - WhereAddPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 109/121 - P2 - AttentionGQAPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 110/121 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 111/121 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 112/121 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 113/121 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 114/121 - P2 - GroupQueryAttention3DPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 115/121 - P2 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 116/121 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 117/121 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 118/121 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 119/121 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 120/121 - P3 - ReshapeGemmReshapePattern()
[GraphBuilderPatternOptimization-KSU.optimize] use pattern 121/121 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-KSU.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-KSU.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-KSU.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.011 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilder-KSU.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-KSU.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-KSU.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-KSU.optimize] increase priority to 1
[GraphBuilderPatternOptimization-KSU.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-KSU.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.009 | max_time=SameChildrenPattern:0.000
[GraphBuilder-KSU.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-KSU.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-KSU.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-KSU.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-KSU.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.007 | max_time=IdentityPattern:0.001
[GraphBuilderPatternOptimization-KSU.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-KSU.optimize] increase priority to 2
[GraphBuilderPatternOptimization-KSU.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-KSU.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.007 | max_time=IdentityPattern:0.000
[GraphBuilder-KSU.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-KSU.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-KSU.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-KSU.optimize] increase priority to 3
[GraphBuilderPatternOptimization-KSU.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-KSU.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-KSU.optimize] done after 8 iterations with 29 nodes in 0.078
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0003799129990511574
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0013199890017858706
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00040493699634680524
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0010906219977186993
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00044314600381767377
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.001616968002053909
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0003647859994089231
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=1.496501499786973e-05
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.001112366997404024
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0009510379968560301
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008700979960849509
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007897210016380996
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0015978989904397167
    STAT iteration_0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.013713204003579449
    STAT iteration_1 +0 -0 #it=1 maxmatch=0 i=0 - time=0.004929210001137108
    STAT iteration_2 +0 -0 #it=1 maxmatch=0 i=0 - time=0.012362689005385619
    STAT iteration_3 +0 -0 #it=1 maxmatch=0 i=0 - time=0.009276724995288532
    STAT iteration_4 +0 -0 #it=1 maxmatch=0 i=0 - time=0.012674511002842337
    STAT iteration_5 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00957526899583172
    STAT iteration_6 +0 -0 #it=1 maxmatch=0 i=0 - time=0.007365725999989081
    STAT match_AttentionGQAPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.454598719254136e-05
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00040083199564833194
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0003731440083356574
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020832900918321684
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005890430111321621
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0007036970055196434
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003282509933342226
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002449460080242716
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005555180105147883
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0003622660005930811
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002456459842505865
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003663100069388747
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00041878999763866886
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000358330988092348
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002621790117700584
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021522199676837772
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020380900241434574
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00011477199586806819
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003001199947902933
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001775639902916737
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001969820004887879
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002937300014309585
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019026400696020573
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002177159913117066
    STAT match_FunctionAttentionGQAPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003811080168816261
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0014863910037092865
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000313991004077252
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002463549972162582
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021250700228847563
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002039320024778135
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.83270040503703e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00020081800903426483
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00048233500274363905
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011378400085959584
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.085299825528637e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002626529967528768
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003265564992034342
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003978441010985989
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.243005140684545e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019486799283185974
    STAT match_GroupQueryAttention3DPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00011627699859673157
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0034807730044121854
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00043891699897358194
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023871299345046282
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004011420016468037
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.423200051765889e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.001080175003153272
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002653400006238371
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021976999414619058
    STAT match_MissingReduceMaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022433100821217522
    STAT match_MissingTopKPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025139900390058756
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004922000007354654
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002542769943829626
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0001642420029384084
    STAT match_NotNotPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001914429885800928
    STAT match_NotWherePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018558299052529037
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003982310008723289
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021956599812256172
    STAT match_RMSNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019225399591960013
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020280000171624124
    STAT match_ReduceArgTopKPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022197000362211838
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002507500030333176
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018912400992121547
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007343459947151132
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.997800038428977e-05
    STAT match_ReshapeGemmReshapePattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.693499845918268e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004933130039717071
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030413500644499436
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00045125799806555733
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003016780101461336
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004875790109508671
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022467200324172154
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0006565860021510161
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0013903679937357083
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002331449941266328
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021523699251702055
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028336900868453085
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006248749850783497
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000579297004151158
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002832770114764571
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000693285015586298
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032220200955634937
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007400510148727335
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004116850032005459
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003413759986869991
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004142869947827421
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002988049964187667
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003461660089669749
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000301182983093895
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039626099169254303
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00027409198810346425
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028284400468692183
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002075140000670217
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002362770028412342
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002739189949352294
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005445867995149456
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002443630000925623
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000332327006617561
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00048581101145828143
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00032259499130304903
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033409999014111236
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002990829962072894
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002516960012144409
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00039179901068564504
    STAT match_SwapRangeAddScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002471970146871172
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00048597600834909827
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004732509914902039
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006613340010517277
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00036211300175637007
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003489570153760724
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00014889500016579404
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004976970012648962
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004932349911541678
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00048223399790003896
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003626409961725585
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003360579939908348
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000418407995312009
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029855100001441315
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003244840117986314
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004890339841949753
    STAT match_WhereAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034541900822659954
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=7.171400648076087e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002848258001904469
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.002574593003373593
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-KSU.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 29 nodes, 20 initializers
[OrderOptimization.shape_order] done after in 0.00021615599689539522s with changed=0 scale=0
[GraphBuilder-KSU.optimize] done with 29 nodes in 0.094
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Comparison optimized and not optimized?

The following tool tries to match the nodes and the shape inference results of two models. If the models are not too different, the function is able to find the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.

# Execute the original and the optimized model with the given evaluator class
# and align their intermediate results (the log below reports an edit distance).
res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
# Render the aligned results as text: one row per pair, first model on the
# left, second model on the right.
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA int64    1:1                  BAAA                 in |
004 - | INITIA int64    1:1                  AAAA                 in |
005 - | INITIA int64    1:1                  EAAA                 in |
006 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              BAAB                 p_
007 ~ | INITIA float32  2:16x16              BAAB                 p_ | INITIA float32  2:16x16              AAAD                 p_
008 ~ | INITIA float32  2:16x16              AAAD                 p_ | INITIA float32  2:16x16              ZYBA                 p_
009 ~ | INITIA float32  2:16x16              ZYBA                 p_ | INITIA float32  2:30x30              KGSP                 sl
010 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
011 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              ABZY                 p_
012 ~ | INITIA float32  2:16x16              ABZY                 p_ | INITIA float32  2:16x16              AAAB                 p_
013 ~ | INITIA float32  2:16x16              AAAB                 p_ | INITIA float32  2:16x16              AYDB                 p_
014 ~ | INITIA float32  2:16x16              AYDB                 p_ | INITIA float32  2:30x30              KGSP                 sl
015 = | INITIA float32  2:32x16              YAAB                 p_ | INITIA float32  2:32x16              YAAB                 p_
016 = | INITIA float32  2:16x128             ZXCD                 p_ | INITIA float32  2:16x128             ZXCD                 p_
017 = | INITIA float32  2:128x16             AACA                 p_ | INITIA float32  2:128x16             AACA                 p_
018 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
019 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
020 = | INITIA float32  2:1024x16            EXPH                 em | INITIA float32  2:1024x16            EXPH                 em
021 = | INITIA float32  2:1024x16            PGMP                 em | INITIA float32  2:1024x16            PGMP                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                AAAB                 de | INITIA float32  1:128                AAAB                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 - | RESULT int64    1:2                  ABAA Concat          Sl |
027 - | RESULT int64    1:2                  EEAA Concat          Sl |
028 - | RESULT int64    1:2                  AAAA Concat          Sl |
029 = | RESULT float32  3:1x30x16            LFVS Gather          em | RESULT float32  3:1x30x16            LFVS Gather          em
030 = | RESULT float32  3:1x30x16            OREZ Gather          em | RESULT float32  3:1x30x16            OREZ Gather          em
031 ~ | RESULT float32  3:1x30x16            ZXAR Add             ad | RESULT float32  3:1x30x16            ZBXD SkipLayerNormal _o
032 ~ | RESULT float32  3:1x30x16            ZBXD LayerNormalizat _o | RESULT float32  3:1x30x1             ABAZ SkipLayerNormal un
033 ~ | RESULT float32  3:1x30x16            UBZF MatMul          li | RESULT float32  3:1x30x1             GGGE SkipLayerNormal un
034 ~ | RESULT float32  3:1x30x16            AEEC MatMul          li | RESULT float32  3:1x30x16            ZXAR SkipLayerNormal ad
035 ~ | RESULT float32  3:1x30x16            THYD MatMul          li | RESULT float32  3:1x30x16            UBZF MatMul          li
036 ~ | RESULT float32  3:1x16x30            PHII Transpose       tr | RESULT float32  3:1x30x16            AEEC MatMul          li
037 ~ | RESULT float32  3:1x30x30            RSPV MatMul          ma | RESULT float32  3:1x30x30            YEDF FusedMatMul     _o
038 ~ | RESULT float32  3:1x30x30            YEDF Mul             _o | RESULT float32  3:1x30x16            THYD MatMul          li
039 - | RESULT float32  2:30x30              KGSP Slice           sl |
040 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
041 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
042 = | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x30            IHHH Softmax         so
043 = | RESULT float32  3:1x30x16            YYAA MatMul          ma | RESULT float32  3:1x30x16            YYAA MatMul          ma
044 = | RESULT float32  3:1x30x16            GCMB MatMul          li | RESULT float32  3:1x30x16            GCMB MatMul          li
045 = | RESULT float32  3:1x30x16            UAXB MatMul          li | RESULT float32  3:1x30x16            UAXB MatMul          li
046 ~ | RESULT float32  3:1x30x16            CFBZ MatMul          li | RESULT float32  3:1x30x30            AXEA FusedMatMul     _o
047 ~ | RESULT float32  3:1x16x30            URBH Transpose       tr | RESULT float32  3:1x30x16            CFBZ MatMul          li
048 ~ | RESULT float32  3:1x30x30            BOQB MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
049 ~ | RESULT float32  3:1x30x30            AXEA Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
050 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IHHH Softmax         so
051 - | RESULT bool     2:30x30              HLZC Equal           eq |
052 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            QEDB MatMul          ma
053 ~ | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x32            MDDD Concat          ca
054 ~ | RESULT float32  3:1x30x16            QEDB MatMul          ma | RESULT float32  3:1x30x16            VAAA MatMul          _o
055 ~ | RESULT float32  3:1x30x32            MDDD Concat          ca | RESULT float32  3:1x30x16            SXYX Add             li
056 ~ | RESULT float32  3:1x30x16            VAAA MatMul          _o | RESULT float32  3:1x30x16            AAXD SkipLayerNormal _o
057 ~ | RESULT float32  3:1x30x16            SXYX Add             li | RESULT float32  3:1x30x1             ABAZ SkipLayerNormal un
058 ~ | RESULT float32  3:1x30x16            QUXO Add             ad | RESULT float32  3:1x30x1             GGGE SkipLayerNormal un
059 ~ | RESULT float32  3:1x30x16            AAXD LayerNormalizat _o | RESULT float32  3:1x30x16            QUXO SkipLayerNormal ad
060 = | RESULT float32  3:1x30x128           SBXL MatMul          _o | RESULT float32  3:1x30x128           SBXL MatMul          _o
061 = | RESULT float32  3:1x30x128           CNGY Add             li | RESULT float32  3:1x30x128           CNGY Add             li
062 = | RESULT float32  3:1x30x128           GDQJ Relu            re | RESULT float32  3:1x30x128           GDQJ Relu            re
063 = | RESULT float32  3:1x30x16            DGGD MatMul          _o | RESULT float32  3:1x30x16            DGGD MatMul          _o
064 = | RESULT float32  3:1x30x16            EHIE Add             li | RESULT float32  3:1x30x16            EHIE Add             li
065 = | RESULT float32  3:1x30x16            VBES Add             ou | RESULT float32  3:1x30x16            VBES Add             ou
066 = | OUTPUT float32  3:1x30x16            VBES                 ou | OUTPUT float32  3:1x30x16            VBES                 ou

The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that’s a topic for another example.

Total running time of the script: (0 minutes 2.673 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

Check the exporter on a dummy from HuggingFace

to_onnx and a custom operator registered with a function

to_onnx and a custom operator registered with a function

Gallery generated by Sphinx-Gallery