to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        _B, T, C = x.size()

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)

        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x

        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention

        return ff


class LLM(torch.nn.Module):

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y


llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-3.998725414276123, max=4.258999824523926

First conversion to ONNX

The conversion relies on torch.export.export(), which gives the graph below.
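
A minimal sketch to reproduce this printout (assuming the same llm and input_ids defined above):

ep = torch.export.export(llm, (input_ids,))
print(ep.graph)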

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

The function to_onnx then converts it into ONNX.

onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
  Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
    Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
      Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
        Softmax(masked_fill, axis=-1) -> softmax
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
        MatMul(softmax, linear_2) -> matmul_1
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
      Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there are no discrepancies.

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.998725414276123, max=4.258999824523926
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. Function to_onnx calls torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.
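
A minimal sketch of that step (assuming the same exported program as in the previous section):

ep = torch.export.export(llm, (input_ids,))
unflat = torch.export.unflatten(ep)
print(unflat.graph)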

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy, but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the called module (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule, taken from the previous graph.
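
A hypothetical check, reusing unflat from the sketch above:

print(type(unflat.get_submodule("embedding")))
# expected: an InterpreterModule, not Embedding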

Now the ONNX graph.

onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[0.0858735...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0444768...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.998725414276123, max=4.258999824523926
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually:

(figure: plot exporter recipes c modules)

Inlining

The ONNX graph can still be inlined after this.
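
Presumably with inline_local_functions imported above; a minimal sketch:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))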

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[0.0858735...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0444768...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
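
The inlined model could be verified the same way as before (a sketch, assuming onx_inlined from the step above):

sess = InferenceSession(onx_inlined.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, dict(input_ids=input_ids.numpy()))[0]
print(f"max discrepancy={max_diff(y, got)['abs']}")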

Optimizations

The ONNX graph produced by the exporter without any optimization is verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True, verbose=2),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-GMA.optimize] start with 73 nodes
[GraphBuilder-GMA.optimize] #patterns=102
[GraphBuilder-GMA.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-GMA.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-GMA.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-GMA.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-GMA.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-GMA.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-GMA.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-GMA.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-GMA.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-GMA.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-GMA.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-GMA.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-GMA.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-GMA.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-GMA.optimize] start with 53 nodes, 28 initializers, 102 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   1/102 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   2/102 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   3/102 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   4/102 - P0 - CastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   5/102 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   6/102 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   7/102 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   8/102 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern   9/102 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  10/102 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  11/102 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  12/102 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  13/102 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  14/102 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  15/102 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  16/102 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  17/102 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  18/102 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  19/102 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  20/102 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  21/102 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  22/102 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  23/102 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  24/102 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  25/102 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  26/102 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  27/102 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  28/102 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  29/102 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  30/102 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  31/102 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  32/102 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  33/102 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  34/102 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  35/102 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  36/102 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  37/102 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  38/102 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  39/102 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  40/102 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  41/102 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  42/102 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  43/102 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  44/102 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  45/102 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  46/102 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  47/102 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  48/102 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  49/102 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  50/102 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  51/102 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  52/102 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  53/102 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  54/102 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  55/102 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  56/102 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  57/102 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  58/102 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  59/102 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  60/102 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  61/102 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  62/102 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  63/102 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  64/102 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  65/102 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  66/102 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  67/102 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  68/102 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  69/102 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  70/102 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  71/102 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  72/102 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  73/102 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  74/102 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  75/102 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  76/102 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  77/102 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  78/102 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  79/102 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  80/102 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  81/102 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  82/102 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  83/102 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  84/102 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  85/102 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  86/102 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  87/102 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  88/102 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  89/102 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  90/102 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  91/102 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  92/102 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  93/102 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  94/102 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  95/102 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  96/102 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  97/102 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  98/102 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern  99/102 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern 100/102 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern 101/102 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-GMA.optimize] use pattern 102/102 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-GMA.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-GMA.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-GMA.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.010 | max_time=LeakyReluPattern:0.003
[GraphBuilder-GMA.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-GMA.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-GMA.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-GMA.optimize] increase priority to 1
[GraphBuilderPatternOptimization-GMA.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-GMA.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.006 | max_time=IdentityPattern:0.000
[GraphBuilder-GMA.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-GMA.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-GMA.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-GMA.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-GMA.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-GMA.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-GMA.optimize] increase priority to 2
[GraphBuilderPatternOptimization-GMA.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-GMA.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-GMA.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-GMA.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-GMA.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-GMA.optimize] increase priority to 3
[GraphBuilderPatternOptimization-GMA.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-GMA.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-GMA.optimize] done after 8 iterations with 29 nodes in 0.060
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0002548019947425928
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0007752280034765135
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00041153999700327404
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0006076949975977186
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00026523300039116293
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014584560012735892
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00018032500156550668
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0015305269989767112
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0025215689893229865
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.001121686007536482
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00034897699879365973
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00026353199427830987
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022426499708672054
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016807999782031402
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004953660027240403
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002619019942358136
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020189000133541413
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00045002899423707277
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002913589996751398
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017342099818051793
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00029383999935816973
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00025210600142600015
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000302759995975066
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00025515599554637447
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001963620015885681
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017219099754584022
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.58470011735335e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002423879959678743
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021300300795701332
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017015399862430058
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00028002400722471066
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016239699471043423
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017317499441560358
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00034142199729103595
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023599599444423802
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017807899712352082
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001708150011836551
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000168914000823861
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=7.982599345268682e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0002139409989467822
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00042650699833757244
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010567799836280756
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.107000055839308e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0028459889981604647
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0036983079990022816
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=9.153001883532852e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017349699555779807
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0028878520024591126
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00034472600600565784
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020343200230854563
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004980925001291325
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.433100108755752e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007248189904203173
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004042430009576492
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023299100212170742
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020878500072285533
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026768699535750784
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002134649948857259
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016336599946953356
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002677409975149203
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001737529964884743
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004960179940098897
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=3.3243999496335164e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004038720035168808
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025304600058007054
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003353100037202239
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025261400514864363
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023754600260872394
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000169676000950858
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000524634997418616
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0009108060039579868
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019250199693487957
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001852469977166038
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002475830042385496
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005157309933565557
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005205150009714998
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000209633002668852
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004917379992548376
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028449999808799475
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004999150041840039
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002887250011553988
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002650119968166109
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003407519980100915
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000288792001811089
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002568949967098888
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022044599973014556
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033386200084351003
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002091530004690867
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020164599845884368
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016313800369971432
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021438199837575667
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018183500651502982
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004370969003502978
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017279399980907328
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017860799562186003
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004235870001139119
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025361599909956567
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003358920002938248
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029166000240365975
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019867499577230774
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002531950012780726
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005215269993641414
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002659130004758481
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002637619945744518
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010730900248745456
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00042982000013580546
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004512730010901578
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029771200206596404
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002935839911515359
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000301863005006453
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025926200032699853
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002488610080035869
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=6.202600343385711e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.002284174996020738
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0025287460011895746
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-GMA.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-GMA.optimize] done with 29 nodes in 0.072
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows the kernel FusedMatMul[com.microsoft], which implements a kernel equivalent to Gemm but working on tensors of any rank, not only 2D ones.
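A minimal sketch of what this kernel computes, emulated with torch (fused_matmul, query and key are illustrative names; alpha=0.25 and transB=1 match the attributes printed above):

import torch


def fused_matmul(a, b, alpha=1.0, transA=False, transB=False):
    # Emulates com.microsoft.FusedMatMul: optionally transpose the last two
    # dimensions of each input, then scale the batched matrix product.
    if transA:
        a = a.transpose(-2, -1)
    if transB:
        b = b.transpose(-2, -1)
    return alpha * (a @ b)


query, key = torch.randn(1, 30, 16), torch.randn(1, 30, 16)
expected = query @ key.transpose(-2, -1) * 16**-0.5  # attention scaling with C=16
assert torch.allclose(fused_matmul(query, key, alpha=0.25, transB=True), expected)

How does the optimizer work on the model which keeps the submodules exported as local functions? It optimizes every local function independently. We reduce the verbosity…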

onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to work just as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
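Conversely, once the submodule information is no longer needed, the local functions can be inlined back into a flat graph. A minimal sketch using onnx.inliner, imported at the top of this page:

inlined = inline_local_functions(onx_module_optimized)
print(pretty_onnx(inlined))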

Optimizations for CUDA

The optimizer may behave differently knowing the model is running on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the best place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more appropriate for some kernels. This could even be decided at runtime.

onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-SUY.optimize] start with 73 nodes
[GraphBuilder-SUY.optimize] #patterns=102
[GraphBuilder-SUY.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-SUY.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-SUY.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-SUY.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-SUY.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-SUY.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-SUY.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-SUY.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-SUY.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-SUY.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-SUY.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-SUY.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-SUY.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-SUY.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-SUY.optimize] start with 53 nodes, 28 initializers, 102 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   1/102 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   2/102 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   3/102 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   4/102 - P0 - CastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   5/102 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   6/102 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   7/102 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   8/102 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern   9/102 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  10/102 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  11/102 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  12/102 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  13/102 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  14/102 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  15/102 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  16/102 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  17/102 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  18/102 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  19/102 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  20/102 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  21/102 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  22/102 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  23/102 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  24/102 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  25/102 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  26/102 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  27/102 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  28/102 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  29/102 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  30/102 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  31/102 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  32/102 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  33/102 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  34/102 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  35/102 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  36/102 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  37/102 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  38/102 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  39/102 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  40/102 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  41/102 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  42/102 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  43/102 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  44/102 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  45/102 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  46/102 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  47/102 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  48/102 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  49/102 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  50/102 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  51/102 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  52/102 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  53/102 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  54/102 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  55/102 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  56/102 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  57/102 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  58/102 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  59/102 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  60/102 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  61/102 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  62/102 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  63/102 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  64/102 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  65/102 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  66/102 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  67/102 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  68/102 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  69/102 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  70/102 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  71/102 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  72/102 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  73/102 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  74/102 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  75/102 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  76/102 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  77/102 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  78/102 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  79/102 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  80/102 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  81/102 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  82/102 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  83/102 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  84/102 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  85/102 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  86/102 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  87/102 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  88/102 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  89/102 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  90/102 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  91/102 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  92/102 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  93/102 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  94/102 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  95/102 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  96/102 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  97/102 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  98/102 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern  99/102 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern 100/102 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern 101/102 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-SUY.optimize] use pattern 102/102 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-SUY.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-SUY.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-SUY.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.010 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilder-SUY.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-SUY.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-SUY.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-SUY.optimize] increase priority to 1
[GraphBuilderPatternOptimization-SUY.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-SUY.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.007 | max_time=IdentityPattern:0.000
[GraphBuilder-SUY.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-SUY.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-SUY.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-SUY.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-SUY.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-SUY.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-SUY.optimize] increase priority to 2
[GraphBuilderPatternOptimization-SUY.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-SUY.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-SUY.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-SUY.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-SUY.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-SUY.optimize] increase priority to 3
[GraphBuilderPatternOptimization-SUY.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-SUY.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-SUY.optimize] done after 8 iterations with 29 nodes in 0.071
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0004001300003437791
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0007625609978276771
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.0005434930026240181
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.001142518001870485
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00023992300339159556
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0018302060052519664
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00020480400053202175
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0022286399944277946
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0029277169960550964
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0013579099977505393
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00042388300062157214
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002763859993137885
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020505299835349433
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004974879993824288
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005873189948033541
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00030627899468527175
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021563500558841042
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0005033180023019668
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0003308420018584002
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002083729996229522
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00037799599886056967
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00029884099785704166
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0003885359983542003
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00028518099497887306
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023479500305256806
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023952700576046482
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00012263599637662992
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00026782700297189876
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016643699564156123
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018564300262369215
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00031291599952965043
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001643510040594265
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000211926999327261
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00045577200216939673
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002869649979402311
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021156100046937354
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021713599562644958
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020413500533322804
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=0.00011613699825829826
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00026942299882648513
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0006320649990811944
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011712299965438433
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00012966299982508644
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003621885000029579
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004384632000437705
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=9.997002052841708e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019815799896605313
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.003103918999840971
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00040421200174023397
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002265149960294366
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0035590259976743255
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.0001184169996122364
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0009496050042798743
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005370149992813822
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002621729981910903
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026368799444753677
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047615700168535113
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029637700208695605
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020704700364149176
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025196600108756684
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019581999731599353
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005473920027725399
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.377200068323873e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00046146200475050136
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029131999690434895
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004819069981749635
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003047139980481006
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003608069964684546
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020573699293890968
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0006103899977460969
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0011404469951230567
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022927100144443102
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024330299493158236
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00029985799847054295
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006251760023587849
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005647260004479904
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002626890018291306
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005325899983290583
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003419999957259279
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006673770076304208
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033527699270052835
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030704099117428996
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00039616600042791106
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00030002400308148935
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00031352799851447344
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002929720030806493
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004557960019155871
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002719270050874911
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029399800405371934
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022404299670597538
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002147269988199696
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022955799795454368
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.005404869003541535
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022914999863132834
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023763300487189554
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005084000003989786
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003009449974342715
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002979319942824077
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002884050009015482
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023460100055672228
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028297199241933413
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006111690017860383
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003203919986844994
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032130200270330533
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.0001943259994732216
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005349279999791179
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005746090027969331
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033283300581388175
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000331339004333131
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003782540043175686
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00034255099672009237
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00034368700289633125
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=7.413700586766936e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0026438160020916257
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0026407670011394657
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-SUY.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-SUY.optimize] done with 29 nodes in 0.086
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
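
The optimized graph relies on operators from the com.microsoft domain, such as SkipLayerNormalization and FusedMatMul, which only onnxruntime implements. A minimal sketch to list them, assuming onx_optimized is the optimized ModelProto printed above:

from collections import Counter

# Count node types per domain; the com.microsoft entries are the fused
# operators implemented only by onnxruntime.
counts = Counter(
    (node.domain or "ai.onnx", node.op_type) for node in onx_optimized.graph.node
)
for (domain, op_type), n in sorted(counts.items()):
    print(f"{domain:>15} {op_type:<25} {n}")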

Comparison between the optimized and non-optimized models

The following tool tries to match the nodes and inferred shapes of two models. If the models are not too different, it is able to align their intermediate results and highlight the differences. We can use it to see which operators were fused into bigger ones implemented only by onnxruntime. In the report below, = marks a result aligned and identical in both models, ~ a result aligned but different, and - a result present only in the first model; each row also shows the type and shape of the result, a short fingerprint of its content, the operator that produced it, and the first letters of its name.

res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA int64    1:1                  BAAA                 in |
004 - | INITIA int64    1:1                  AAAA                 in |
005 - | INITIA int64    1:1                  EAAA                 in |
006 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              BBBA                 p_
007 ~ | INITIA float32  2:16x16              BBBA                 p_ | INITIA float32  2:16x16              AYAA                 p_
008 ~ | INITIA float32  2:16x16              AYAA                 p_ | INITIA float32  2:16x16              ABAZ                 p_
009 ~ | INITIA float32  2:16x16              ABAZ                 p_ | INITIA float32  2:30x30              KGSP                 sl
010 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
011 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              AZAA                 p_
012 ~ | INITIA float32  2:16x16              AZAA                 p_ | INITIA float32  2:16x16              ZZAZ                 p_
013 ~ | INITIA float32  2:16x16              ZZAZ                 p_ | INITIA float32  2:16x16              AAZY                 p_
014 ~ | INITIA float32  2:16x16              AAZY                 p_ | INITIA float32  2:30x30              KGSP                 sl
015 = | INITIA float32  2:32x16              ABZA                 p_ | INITIA float32  2:32x16              ABZA                 p_
016 = | INITIA float32  2:16x128             YCXW                 p_ | INITIA float32  2:16x128             YCXW                 p_
017 = | INITIA float32  2:128x16             AAZB                 p_ | INITIA float32  2:128x16             AAZB                 p_
018 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
019 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
020 = | INITIA float32  2:1024x16            EYBT                 em | INITIA float32  2:1024x16            EYBT                 em
021 = | INITIA float32  2:1024x16            WEII                 em | INITIA float32  2:1024x16            WEII                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                AAAA                 de | INITIA float32  1:128                AAAA                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 - | RESULT int64    1:2                  ABAA Concat          Sl |
027 - | RESULT int64    1:2                  EEAA Concat          Sl |
028 - | RESULT int64    1:2                  AAAA Concat          Sl |
029 = | RESULT float32  3:1x30x16            IKKC Gather          em | RESULT float32  3:1x30x16            IKKC Gather          em
030 = | RESULT float32  3:1x30x16            ZHUS Gather          em | RESULT float32  3:1x30x16            ZHUS Gather          em
031 ~ | RESULT float32  3:1x30x16            GRDU Add             ad | RESULT float32  3:1x30x16            YCAA SkipLayerNormal _o
032 ~ | RESULT float32  3:1x30x16            YCAA LayerNormalizat _o | RESULT float32  3:1x30x1             AAAA SkipLayerNormal un
033 ~ | RESULT float32  3:1x30x16            YAXD MatMul          li | RESULT float32  3:1x30x1             FHGE SkipLayerNormal un
034 ~ | RESULT float32  3:1x30x16            FCFB MatMul          li | RESULT float32  3:1x30x16            GRDU SkipLayerNormal ad
035 ~ | RESULT float32  3:1x30x16            UWBO MatMul          li | RESULT float32  3:1x30x16            YAXD MatMul          li
036 ~ | RESULT float32  3:1x16x30            ZEVQ Transpose       tr | RESULT float32  3:1x30x16            FCFB MatMul          li
037 ~ | RESULT float32  3:1x30x30            GLET MatMul          ma | RESULT float32  3:1x30x30            BCVE FusedMatMul     _o
038 ~ | RESULT float32  3:1x30x30            BCVE Mul             _o | RESULT float32  3:1x30x16            UWBO MatMul          li
039 - | RESULT float32  2:30x30              KGSP Slice           sl |
040 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
041 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
042 = | RESULT float32  3:1x30x30            HGHH Softmax         so | RESULT float32  3:1x30x30            HGHH Softmax         so
043 = | RESULT float32  3:1x30x16            SVYX MatMul          ma | RESULT float32  3:1x30x16            SVYX MatMul          ma
044 = | RESULT float32  3:1x30x16            KFEW MatMul          li | RESULT float32  3:1x30x16            KFEW MatMul          li
045 = | RESULT float32  3:1x30x16            GCXZ MatMul          li | RESULT float32  3:1x30x16            GCXZ MatMul          li
046 ~ | RESULT float32  3:1x30x16            DDAG MatMul          li | RESULT float32  3:1x30x30            VVEY FusedMatMul     _o
047 ~ | RESULT float32  3:1x16x30            YZWN Transpose       tr | RESULT float32  3:1x30x16            DDAG MatMul          li
048 ~ | RESULT float32  3:1x30x30            EETP MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
049 ~ | RESULT float32  3:1x30x30            VVEY Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
050 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            HHHH Softmax         so
051 - | RESULT bool     2:30x30              HLZC Equal           eq |
052 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            XFBD MatMul          ma
053 ~ | RESULT float32  3:1x30x30            HHHH Softmax         so | RESULT float32  3:1x30x32            NBZA Concat          ca
054 ~ | RESULT float32  3:1x30x16            XFBD MatMul          ma | RESULT float32  3:1x30x16            CAAA MatMul          _o
055 ~ | RESULT float32  3:1x30x32            NBZA Concat          ca | RESULT float32  3:1x30x16            CAAB Add             li
056 ~ | RESULT float32  3:1x30x16            CAAA MatMul          _o | RESULT float32  3:1x30x16            YCAA SkipLayerNormal _o
057 ~ | RESULT float32  3:1x30x16            CAAB Add             li | RESULT float32  3:1x30x1             AAAA SkipLayerNormal un
058 ~ | RESULT float32  3:1x30x16            IQEV Add             ad | RESULT float32  3:1x30x1             FHGE SkipLayerNormal un
059 ~ | RESULT float32  3:1x30x16            YCAA LayerNormalizat _o | RESULT float32  3:1x30x16            IQEV SkipLayerNormal ad
060 = | RESULT float32  3:1x30x128           DEOA MatMul          _o | RESULT float32  3:1x30x128           DEOA MatMul          _o
061 = | RESULT float32  3:1x30x128           RUBP Add             li | RESULT float32  3:1x30x128           RUBP Add             li
062 = | RESULT float32  3:1x30x128           ECRE Relu            re | RESULT float32  3:1x30x128           ECRE Relu            re
063 = | RESULT float32  3:1x30x16            BADW MatMul          _o | RESULT float32  3:1x30x16            BADW MatMul          _o
064 = | RESULT float32  3:1x30x16            BAEW Add             li | RESULT float32  3:1x30x16            BAEW Add             li
065 = | RESULT float32  3:1x30x16            KQIR Add             ou | RESULT float32  3:1x30x16            KQIR Add             ou
066 = | OUTPUT float32  3:1x30x16            KQIR                 ou | OUTPUT float32  3:1x30x16            KQIR                 ou
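
The last row shows both models produce the same output. To double-check numerically, a short sketch using the imports above; input_ids is assumed to be the int64 tensor of shape (1, 30) fed to the models:

sess1 = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
sess2 = InferenceSession(onx_optimized.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
# max_diff reports the maximum absolute and relative discrepancies.
print(max_diff(sess1.run(None, feeds), sess2.run(None, feeds)))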

The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that's a topic for another example.
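
Even though that example is deferred, here is a hedged sketch of what the export could look like, assuming to_onnx forwards a torch.export-style dynamic_shapes argument; llm stands for the model instance created earlier:

seq = torch.export.Dim("seq", min=2, max=256)  # the 256x256 mask buffer caps the length
onx_dyn = to_onnx(llm, (input_ids,), dynamic_shapes=({1: seq},))
print(pretty_onnx(onx_dyn))  # the sequence axis should now be symbolic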

Total running time of the script: (0 minutes 2.525 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

to_onnx and a model with a test
