to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        _B, T, C = x.size()

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)

        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x

        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention

        return ff


class LLM(torch.nn.Module):

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y


llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-3.95937180519104, max=3.7651400566101074

First conversion to ONNX

The conversion relies on torch.export.export().
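
A minimal sketch of that step (assuming the llm and input_ids defined above):

ep = torch.export.export(llm, (input_ids,))
print(ep.graph)

which gives: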

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

The function to_onnx then converts it into ONNX.

onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
  Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
    Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
      Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
        Softmax(masked_fill, axis=-1) -> softmax
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
        MatMul(softmax, linear_2) -> matmul_1
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
      Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there is no discrepancy.

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.95937180519104, max=3.7651400566101074
max discrepancy=2.980232238769531e-07

Let’s save the ONNX model.

onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")
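
As a quick sanity check, the saved file can be reloaded and validated with onnx.checker (a minimal sketch; check_model raises if the model is malformed):

onx2 = onnx.load("plot_exporter_recipes_c_modules.inlined.onnx")
onnx.checker.check_model(onx2)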

ONNX with submodules

Let’s produce an ONNX model with submodules. The function to_onnx calls torch.export.unflatten.unflatten() under the hood.
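
A minimal sketch of that intermediate step (to_onnx performs the equivalent internally):

ep = torch.export.export(llm, (input_ids,))
unflat = torch.export.unflatten(ep)
print(unflat.graph)

The FX graph looks like the following.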

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy but does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the FX nodes contributing to the submodule, coming from the previous graph.
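
We can verify it on the unflattened module (a minimal check, assuming the unflat object from the sketch above):

print(type(unflat.embedding))  # an InterpreterModule, not an Embedding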

Now the ONNX graph.

onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[-0.146975...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0807273...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check again that there are no new discrepancies.

sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.95937180519104, max=3.7651400566101074
max discrepancy=2.980232238769531e-07

Let’s save the ONNX model.

onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.

[figure: plot exporter recipes c modules]

Inlining

The ONNX graph can still be inlined after this.
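
A minimal sketch of that step, relying on inline_local_functions imported at the top:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))

which gives: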

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[-0.146975...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0807273...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True, verbose=2),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-KFO.optimize] start with 73 nodes
[GraphBuilder-KFO.optimize] #patterns=106
[GraphBuilder-KFO.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-KFO.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-KFO.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-KFO.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-KFO.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-KFO.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-KFO.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-KFO.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-KFO.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-KFO.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-KFO.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-KFO.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-KFO.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-KFO.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-KFO.optimize] start with 53 nodes, 28 initializers, 106 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   1/106 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   2/106 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   3/106 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   4/106 - P0 - CastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   5/106 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   6/106 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   7/106 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   8/106 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern   9/106 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  10/106 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  11/106 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  12/106 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  13/106 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  14/106 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  15/106 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  16/106 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  17/106 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  18/106 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  19/106 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  20/106 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  21/106 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  22/106 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  23/106 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  24/106 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  25/106 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  26/106 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  27/106 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  28/106 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  29/106 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  30/106 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  31/106 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  32/106 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  33/106 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  34/106 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  35/106 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  36/106 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  37/106 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  38/106 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  39/106 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  40/106 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  41/106 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  42/106 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  43/106 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  44/106 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  45/106 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  46/106 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  47/106 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  48/106 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  49/106 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  50/106 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  51/106 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  52/106 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  53/106 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  54/106 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  55/106 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  56/106 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  57/106 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  58/106 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  59/106 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  60/106 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  61/106 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  62/106 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  63/106 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  64/106 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  65/106 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  66/106 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  67/106 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  68/106 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  69/106 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  70/106 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  71/106 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  72/106 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  73/106 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  74/106 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  75/106 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  76/106 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  77/106 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  78/106 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  79/106 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  80/106 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  81/106 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  82/106 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  83/106 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  84/106 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  85/106 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  86/106 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  87/106 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  88/106 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  89/106 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  90/106 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  91/106 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  92/106 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  93/106 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  94/106 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  95/106 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  96/106 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  97/106 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  98/106 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern  99/106 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern 100/106 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern 101/106 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern 102/106 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern 103/106 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern 104/106 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern 105/106 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-KFO.optimize] use pattern 106/106 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-KFO.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-KFO.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-KFO.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.008 | max_time=GeluErfPattern:0.003
[GraphBuilder-KFO.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-KFO.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-KFO.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-KFO.optimize] increase priority to 1
[GraphBuilderPatternOptimization-KFO.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-KFO.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilder-KFO.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-KFO.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-KFO.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-KFO.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-KFO.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-KFO.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-KFO.optimize] increase priority to 2
[GraphBuilderPatternOptimization-KFO.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-KFO.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.004 | max_time=FusedMatMulPattern:0.000
[GraphBuilder-KFO.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-KFO.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-KFO.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-KFO.optimize] increase priority to 3
[GraphBuilderPatternOptimization-KFO.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-KFO.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-KFO.optimize] done after 8 iterations with 29 nodes in 0.066
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002785400029097218
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0007741319968772586
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00038889200004632585
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0015060650002851617
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00031017899891594425
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014830630025244318
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00012837800022680312
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0018849560001399368
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.002748086004430661
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.001344079999398673
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.000367650998668978
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002658359953784384
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020588000552379526
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021321000531315804
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005341340001905337
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00025045299844350666
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020181100262561813
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00044039600106771104
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00025917899620253593
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016240100012510084
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00028977600231883116
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002536450047045946
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00030568099828087725
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002635550044942647
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001995330057980027
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020719700478366576
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=9.988999954657629e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022879500102135353
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013462900824379176
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015620100020896643
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022677900051348843
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016212199989240617
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001895770001283381
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003389129997231066
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029159900441300124
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018057800480164587
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017831399964052252
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018087100397679023
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.77780003065709e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00017039299928001128
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004947740053466987
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.727099702809937e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.007499905535951e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005403458002547268
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00471641200419981
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.246996003435925e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001961150010174606
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0026906130005954765
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003574269940145314
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001786620014172513
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0030168439989211038
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.0001592030021129176
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0008119039994198829
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019058300676988438
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017669900626060553
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005628620019706432
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002456520051055122
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022409500161302276
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002905080036725849
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002696330011531245
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018325200289837085
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002562239969847724
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001659520021348726
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004968519970134366
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.086400127154775e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035739799204748124
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002575290054664947
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004358690021035727
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002493479987606406
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025033700876520015
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001530049994471483
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00046703199768671766
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0008721639969735406
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001857490060501732
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021576600192929618
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024353500339202583
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005697879969375208
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005392069979279768
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023950500326463953
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00048178599899983965
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028379300056258217
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005178149986022618
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030305500331451185
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026816099853022024
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004312529999879189
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002665720021468587
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024544600455556065
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002796680018946063
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003499750018818304
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002292460012540687
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022236900258576497
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016901699927984737
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001559469965286553
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016798100114101544
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004816328000742942
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017110099724959582
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018360700414632447
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004737270028272178
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024199500694521703
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025151298905257136
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002393769937043544
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020243100152583793
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002555760074756108
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003808049950748682
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005266939915600233
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003053330037801061
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003021220072696451
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.028500062413514e-05
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027568199220695533
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00044922200686414726
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004107210006623063
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002709809996304102
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027245400269748643
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003635100001702085
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002451539949106518
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002463490018271841
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=6.431300062104128e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0028951890017197
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.002540664998377906
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-KFO.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-KFO.optimize] done with 29 nodes in 0.075
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows the kernel FusedMatMul[com.microsoft], which implements a kernel equivalent to Gemm but working on tensors of any rank, not only 2D ones.
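
A minimal sketch of what this kernel computes, reusing the torch import from the beginning: it mimics the semantics of FusedMatMul with alpha=0.25, transA=0, transB=1 as it appears in the graph above; it is not the onnxruntime implementation itself.

query = torch.randn(1, 30, 16)
key = torch.randn(1, 30, 16)
# alpha * (query @ key^T) on the last two axes; leading batch dimensions
# are carried along, which Gemm (2D only) cannot do.
fused_equiv = 0.25 * (query @ key.transpose(-2, -1))
print(fused_equiv.shape)  # torch.Size([1, 30, 30])

How does the optimizer work on the model which exports the modules as local functions? It optimizes every local function independently. We reduce the verbosity…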

onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to work on this simple case as well, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components; the sketch below enumerates the submodules that were kept.
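
A minimal sketch, using only the standard onnx API: it lists the local functions preserved in the exported model, the structure such a per-module optimizer could match on. It prints nothing if every function was inlined during optimization.

for f in onx_module_optimized.functions:
    ops = sorted({n.op_type for n in f.node})
    print(f"{f.domain}.{f.name}: {len(f.node)} nodes, ops={ops}")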

Optimizations for CUDA

The optimizer may behave differently when it knows the model will run on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the best place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more suitable for a given kernel. This could even be decided at runtime, as the short sketch below suggests.
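
A minimal sketch of that runtime side, assuming a CUDA build of onnxruntime is installed; it reuses the model exported above. The runtime assigns each kernel to an execution provider and silently falls back to CPU when no CUDA kernel exists.

sess = InferenceSession(
    onx_module_optimized.SerializeToString(),
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Providers actually enabled for this session, in priority order.
print(sess.get_providers())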

onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-FPO.optimize] start with 73 nodes
[GraphBuilder-FPO.optimize] #patterns=106
[GraphBuilder-FPO.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-FPO.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-FPO.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-FPO.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-FPO.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-FPO.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-FPO.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-FPO.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-FPO.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-FPO.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-FPO.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-FPO.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-FPO.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-FPO.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-FPO.optimize] start with 53 nodes, 28 initializers, 106 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   1/106 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   2/106 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   3/106 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   4/106 - P0 - CastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   5/106 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   6/106 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   7/106 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   8/106 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern   9/106 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  10/106 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  11/106 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  12/106 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  13/106 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  14/106 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  15/106 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  16/106 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  17/106 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  18/106 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  19/106 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  20/106 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  21/106 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  22/106 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  23/106 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  24/106 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  25/106 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  26/106 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  27/106 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  28/106 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  29/106 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  30/106 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  31/106 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  32/106 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  33/106 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  34/106 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  35/106 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  36/106 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  37/106 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  38/106 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  39/106 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  40/106 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  41/106 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  42/106 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  43/106 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  44/106 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  45/106 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  46/106 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  47/106 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  48/106 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  49/106 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  50/106 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  51/106 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  52/106 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  53/106 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  54/106 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  55/106 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  56/106 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  57/106 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  58/106 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  59/106 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  60/106 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  61/106 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  62/106 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  63/106 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  64/106 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  65/106 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  66/106 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  67/106 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  68/106 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  69/106 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  70/106 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  71/106 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  72/106 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  73/106 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  74/106 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  75/106 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  76/106 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  77/106 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  78/106 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  79/106 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  80/106 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  81/106 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  82/106 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  83/106 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  84/106 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  85/106 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  86/106 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  87/106 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  88/106 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  89/106 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  90/106 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  91/106 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  92/106 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  93/106 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  94/106 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  95/106 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  96/106 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  97/106 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  98/106 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern  99/106 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern 100/106 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern 101/106 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern 102/106 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern 103/106 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern 104/106 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern 105/106 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-FPO.optimize] use pattern 106/106 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-FPO.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-FPO.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-FPO.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilder-FPO.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-FPO.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-FPO.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-FPO.optimize] increase priority to 1
[GraphBuilderPatternOptimization-FPO.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-FPO.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-FPO.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-FPO.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-FPO.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-FPO.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-FPO.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-FPO.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-FPO.optimize] increase priority to 2
[GraphBuilderPatternOptimization-FPO.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-FPO.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilder-FPO.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-FPO.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-FPO.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-FPO.optimize] increase priority to 3
[GraphBuilderPatternOptimization-FPO.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-FPO.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-FPO.optimize] done after 8 iterations with 29 nodes in 0.044
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00017202099843416363
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0005579899989243131
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00023630599753232673
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.000537899999471847
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00019060899649048224
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0011437820030550938
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011912700210814364
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0011138089939777274
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0018423359979351517
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0008034869970288128
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023073900229064748
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002151320040866267
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012728099682135507
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002545080060372129
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00036941500002285466
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00019891400006599724
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015315999917220324
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003397430082259234
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00022529001216753386
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013039699842920527
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023192400476546027
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002057129968306981
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023179999698186293
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002350469985685777
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001600580035301391
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013062000289210118
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.278900036704727e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00018660299610928632
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011235800411668606
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012769000022672117
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00018578299568616785
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012580500333569944
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013303499872563407
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002628740039654076
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017684399790596217
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013646400111611
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001261689976672642
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012542700278572738
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=5.941900235484354e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00014820000433246605
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0003002160010510124
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.742400051327422e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.8935999075183645e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002151086002413649
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0025342359986098018
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=6.514001142932102e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013118899732944556
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.002193951000663219
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00027334400510881096
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017215699699590914
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0024741810047999024
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.6019001931417733e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005461620021378621
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014203500177245587
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012892800077679567
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030774500555708073
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017452199608669616
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015140199684537947
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002606270063552074
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013233100253273733
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012168200191808864
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017286499860347249
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012978700033272617
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000404534992412664
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.908699778141454e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00027877999673364684
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018734899276751094
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026159600383834913
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018597899907035753
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017516399748274125
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001232190006703604
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003958829984185286
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007327210005314555
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013332800153875723
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018917099805548787
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022639700546278618
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00046021999514778145
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047993599946494214
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016357399363187142
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003427889969316311
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002292140015924815
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00038937800127314404
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021407100211945362
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019645299835246988
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026497799626667984
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000201145994651597
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001878119983302895
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016049400073825382
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002375310032221023
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00016928599507082254
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001575109999976121
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011852199895656668
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013406200014287606
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014330899284686893
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0031689979950897396
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014387599730980583
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012893999883090146
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003248449975217227
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018913799794972874
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019609399896580726
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018715799524215981
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014732100680703297
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020292599583626725
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030974799665273167
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00038851499994052574
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019630999668152072
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020051300089107826
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.67820019973442e-05
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021364199710660614
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030286299443105236
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030248800248955376
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002059720063698478
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020195500474073924
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000233776998356916
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018095699488185346
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001759690021572169
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=4.594799975166097e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0016766880034992937
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0016561860029469244
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-FPO.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-FPO.optimize] done with 29 nodes in 0.052
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

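Before comparing the two graphs node by node, we can quickly check that the optimized model, which relies on com.microsoft operators only implemented by onnxruntime, still computes the same output. This is a minimal sketch: input_ids and expected stand for the input tensor and the eager PyTorch output computed earlier in this example (names assumed).

# Sketch of a quick numerical check: run the optimized model with
# onnxruntime and compare against the eager output with max_diff.
# ``input_ids`` and ``expected`` are assumed from earlier sections.
sess = InferenceSession(
    onx_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, {"input_ids": input_ids.numpy()})
print(max_diff(expected, got[0]))
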
Comparison between the optimized and not optimized models

The following tool tries to align the nodes and the inferred shapes of two models. If the models are not too different, it can line them up and show where they diverge. We use it here to see which operators were fused into bigger ones that only onnxruntime implements. In the side-by-side output, = means both rows hold the same result, ~ means the rows are aligned but differ, and - means the result only exists in the first model.

res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA int64    1:1                  BAAA                 in |
004 - | INITIA int64    1:1                  AAAA                 in |
005 - | INITIA int64    1:1                  EAAA                 in |
006 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              AAAB                 p_
007 ~ | INITIA float32  2:16x16              AAAB                 p_ | INITIA float32  2:16x16              AAZA                 p_
008 ~ | INITIA float32  2:16x16              AAZA                 p_ | INITIA float32  2:16x16              AABB                 p_
009 ~ | INITIA float32  2:16x16              AABB                 p_ | INITIA float32  2:30x30              KGSP                 sl
010 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
011 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              AABA                 p_
012 ~ | INITIA float32  2:16x16              AABA                 p_ | INITIA float32  2:16x16              BAAB                 p_
013 ~ | INITIA float32  2:16x16              BAAB                 p_ | INITIA float32  2:16x16              AAAZ                 p_
014 ~ | INITIA float32  2:16x16              AAAZ                 p_ | INITIA float32  2:30x30              KGSP                 sl
015 = | INITIA float32  2:32x16              ZAYA                 p_ | INITIA float32  2:32x16              ZAYA                 p_
016 = | INITIA float32  2:16x128             DWYA                 p_ | INITIA float32  2:16x128             DWYA                 p_
017 = | INITIA float32  2:128x16             ABBB                 p_ | INITIA float32  2:128x16             ABBB                 p_
018 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
019 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
020 = | INITIA float32  2:1024x16            ELUA                 em | INITIA float32  2:1024x16            ELUA                 em
021 = | INITIA float32  2:1024x16            OYWC                 em | INITIA float32  2:1024x16            OYWC                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                AZAA                 de | INITIA float32  1:128                AZAA                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 - | RESULT int64    1:2                  ABAA Concat          Sl |
027 - | RESULT int64    1:2                  EEAA Concat          Sl |
028 - | RESULT int64    1:2                  AAAA Concat          Sl |
029 = | RESULT float32  3:1x30x16            CDVC Gather          em | RESULT float32  3:1x30x16            CDVC Gather          em
030 = | RESULT float32  3:1x30x16            DOMZ Gather          em | RESULT float32  3:1x30x16            DOMZ Gather          em
031 ~ | RESULT float32  3:1x30x16            FSHB Add             ad | RESULT float32  3:1x30x16            CYAA SkipLayerNormal _o
032 ~ | RESULT float32  3:1x30x16            CYAA LayerNormalizat _o | RESULT float32  3:1x30x1             AACA SkipLayerNormal un
033 ~ | RESULT float32  3:1x30x16            TNMX MatMul          li | RESULT float32  3:1x30x1             FGGE SkipLayerNormal un
034 ~ | RESULT float32  3:1x30x16            AXIT MatMul          li | RESULT float32  3:1x30x16            FSHB SkipLayerNormal ad
035 ~ | RESULT float32  3:1x30x16            WBNA MatMul          li | RESULT float32  3:1x30x16            TNMX MatMul          li
036 ~ | RESULT float32  3:1x16x30            KPFS Transpose       tr | RESULT float32  3:1x30x16            AXIT MatMul          li
037 ~ | RESULT float32  3:1x30x30            AADC MatMul          ma | RESULT float32  3:1x30x30            AAIU FusedMatMul     _o
038 ~ | RESULT float32  3:1x30x30            AAIU Mul             _o | RESULT float32  3:1x30x16            WBNA MatMul          li
039 - | RESULT float32  2:30x30              KGSP Slice           sl |
040 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
041 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
042 = | RESULT float32  3:1x30x30            HGHH Softmax         so | RESULT float32  3:1x30x30            HGHH Softmax         so
043 = | RESULT float32  3:1x30x16            RZAX MatMul          ma | RESULT float32  3:1x30x16            RZAX MatMul          ma
044 = | RESULT float32  3:1x30x16            FYQB MatMul          li | RESULT float32  3:1x30x16            FYQB MatMul          li
045 = | RESULT float32  3:1x30x16            AAHJ MatMul          li | RESULT float32  3:1x30x16            AAHJ MatMul          li
046 ~ | RESULT float32  3:1x30x16            CUET MatMul          li | RESULT float32  3:1x30x30            BMWM FusedMatMul     _o
047 ~ | RESULT float32  3:1x16x30            LYKY Transpose       tr | RESULT float32  3:1x30x16            CUET MatMul          li
048 ~ | RESULT float32  3:1x30x30            FYIZ MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
049 ~ | RESULT float32  3:1x30x30            BMWM Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
050 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IHHH Softmax         so
051 ~ | RESULT bool     2:30x30              HLZC Equal           eq | RESULT float32  3:1x30x16            EYZZ MatMul          ma
052 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x32            VXYW Concat          ca
053 ~ | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x16            ZAAY MatMul          _o
054 ~ | RESULT float32  3:1x30x16            EYZZ MatMul          ma | RESULT float32  3:1x30x16            AAAZ Add             li
055 - | RESULT float32  3:1x30x32            VXYW Concat          ca |
056 ~ | RESULT float32  3:1x30x16            ZAAY MatMul          _o | RESULT float32  3:1x30x16            BZAA SkipLayerNormal _o
057 ~ | RESULT float32  3:1x30x16            AAAZ Add             li | RESULT float32  3:1x30x1             AACA SkipLayerNormal un
058 ~ | RESULT float32  3:1x30x16            ETHA Add             ad | RESULT float32  3:1x30x1             FGGE SkipLayerNormal un
059 ~ | RESULT float32  3:1x30x16            BZAA LayerNormalizat _o | RESULT float32  3:1x30x16            ETHA SkipLayerNormal ad
060 = | RESULT float32  3:1x30x128           UQAV MatMul          _o | RESULT float32  3:1x30x128           UQAV MatMul          _o
061 = | RESULT float32  3:1x30x128           IDOJ Add             li | RESULT float32  3:1x30x128           IDOJ Add             li
062 = | RESULT float32  3:1x30x128           IDVX Relu            re | RESULT float32  3:1x30x128           IDVX Relu            re
063 = | RESULT float32  3:1x30x16            DDGD MatMul          _o | RESULT float32  3:1x30x16            DDGD MatMul          _o
064 = | RESULT float32  3:1x30x16            HHLH Add             li | RESULT float32  3:1x30x16            HHLH Add             li
065 = | RESULT float32  3:1x30x16            MASH Add             ou | RESULT float32  3:1x30x16            MASH Add             ou
066 = | OUTPUT float32  3:1x30x16            MASH                 ou | OUTPUT float32  3:1x30x16            MASH                 ou

The side-by-side view confirms the fusions: the sequences Transpose + MatMul + Mul were replaced by a single FusedMatMul, and every Add followed by LayerNormalization became a SkipLayerNormalization. The conversion should also handle dynamic shapes, since the input sequence can be of any length, but that's a topic for another example; a rough sketch of the export call follows.
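
Assuming to_onnx forwards dynamic_shapes to torch.export the way torch.export.export accepts it, the call could look like the sketch below; the model instance llm and the argument name input_ids are assumptions based on the earlier sections.

# Hedged sketch: export with a dynamic sequence length.
# ``llm`` is the model instance built earlier (name assumed);
# the mask buffer bounds the sequence length by context_size (256 here).
seq = torch.export.Dim("seq", min=2, max=256)
onx_dyn = to_onnx(
    llm,
    (input_ids,),
    dynamic_shapes={"input_ids": {1: seq}},
)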

Total running time of the script: (0 minutes 1.738 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

to_onnx and a custom operator registered with a function