to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        _B, T, C = x.size()

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)

        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x

        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention

        return ff


class LLM(torch.nn.Module):

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y


llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.500553607940674, max=5.789700508117676

First conversion to ONNX

The conversion relies on torch.export.export().
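That call is not shown in the listing above; a minimal sketch of it, reusing the llm and input_ids defined earlier, could be:

ep = torch.export.export(llm, (input_ids,))
print(ep.graph)

which gives: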

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

Then the function to_onnx converts it into ONNX.

onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
  Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
    Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
      Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
        Softmax(masked_fill, axis=-1) -> softmax
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
        MatMul(softmax, linear_2) -> matmul_1
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
      Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check that there is no discrepancy.

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.500553607940674, max=5.789700508117676
max discrepancy=4.76837158203125e-07
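The 'abs' entry reported by max_diff is the maximum absolute difference between the two outputs. As a sketch (not part of the original code), the same quantity can be computed with plain numpy:

import numpy as np

# maximum absolute difference between the PyTorch output and the onnxruntime output
print(np.abs(y.detach().numpy() - got).max())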

Let’s save the ONNX model.

onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. The function to_onnx calls the function torch.export.unflatten.unflatten() under the hood.
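As a sketch (this intermediate step is not shown in the original code), the unflattened module can be built and inspected directly:

ep = torch.export.export(llm, (input_ids,))
unflat = torch.export.unflatten(ep)
print(unflat.graph)

The fx graph looks like the following.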

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy, but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule, which contains the fx nodes contributing to the submodule and coming from the previous graph.
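A quick way to verify this, reusing the hypothetical unflat module from the sketch above:

# expected to print the InterpreterModule class, not Embedding
print(type(unflat.get_submodule("embedding")))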

Now the ONNX graph.

onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[-0.157050...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0367000...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.500553607940674, max=5.789700508117676
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.
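The figure is presumably produced with the plot_dot helper imported at the top, along the lines of:

plot_dot(onx_module)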

[figure: plot exporter recipes c modules]

Inlining

The ONNX graph can still be inlined after this.
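The call itself is not shown above; with the onnx.inliner helper imported at the top, it would be along the lines of:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))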

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[-0.157050...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0367000...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True, verbose=2),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-ACU.optimize] start with 73 nodes
[GraphBuilder-ACU.optimize] #patterns=102
[GraphBuilder-ACU.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-ACU.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-ACU.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-ACU.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-ACU.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-ACU.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-ACU.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-ACU.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-ACU.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-ACU.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-ACU.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-ACU.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-ACU.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-ACU.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-ACU.optimize] start with 53 nodes, 28 initializers, 102 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   1/102 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   2/102 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   3/102 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   4/102 - P0 - CastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   5/102 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   6/102 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   7/102 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   8/102 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern   9/102 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  10/102 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  11/102 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  12/102 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  13/102 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  14/102 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  15/102 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  16/102 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  17/102 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  18/102 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  19/102 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  20/102 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  21/102 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  22/102 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  23/102 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  24/102 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  25/102 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  26/102 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  27/102 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  28/102 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  29/102 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  30/102 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  31/102 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  32/102 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  33/102 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  34/102 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  35/102 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  36/102 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  37/102 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  38/102 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  39/102 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  40/102 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  41/102 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  42/102 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  43/102 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  44/102 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  45/102 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  46/102 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  47/102 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  48/102 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  49/102 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  50/102 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  51/102 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  52/102 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  53/102 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  54/102 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  55/102 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  56/102 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  57/102 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  58/102 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  59/102 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  60/102 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  61/102 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  62/102 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  63/102 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  64/102 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  65/102 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  66/102 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  67/102 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  68/102 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  69/102 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  70/102 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  71/102 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  72/102 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  73/102 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  74/102 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  75/102 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  76/102 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  77/102 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  78/102 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  79/102 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  80/102 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  81/102 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  82/102 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  83/102 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  84/102 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  85/102 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  86/102 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  87/102 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  88/102 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  89/102 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  90/102 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  91/102 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  92/102 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  93/102 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  94/102 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  95/102 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  96/102 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  97/102 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  98/102 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern  99/102 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern 100/102 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern 101/102 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-ACU.optimize] use pattern 102/102 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-ACU.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-ACU.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-ACU.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.009 | max_time=GeluErfPattern:0.002
[GraphBuilder-ACU.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-ACU.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-ACU.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-ACU.optimize] increase priority to 1
[GraphBuilderPatternOptimization-ACU.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-ACU.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-ACU.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-ACU.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-ACU.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-ACU.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-ACU.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-ACU.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-ACU.optimize] increase priority to 2
[GraphBuilderPatternOptimization-ACU.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-ACU.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilder-ACU.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-ACU.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-ACU.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-ACU.optimize] increase priority to 3
[GraphBuilderPatternOptimization-ACU.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-ACU.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-ACU.optimize] done after 8 iterations with 29 nodes in 0.052
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0001885939964267891
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0006855690007796511
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00027474399757920764
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004809579986613244
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00019513599909259938
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0012872370025434066
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0002791379993141163
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0012117680016672239
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.002238924993434921
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0008783430021139793
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00032548400122323073
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023920000239741057
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001557649993628729
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001429980002285447
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00040759900366538204
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002480109942553099
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001963299982890021
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000386555002478417
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002735839952947572
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016561099982936867
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00024530199880246073
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021343499247450382
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002852400066331029
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000235641000472242
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00016231399786192924
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018237400581710972
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=7.715200263191946e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002506000018911436
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012828500257455744
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013850899995304644
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023278001026483253
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013484400187735446
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001536089985165745
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002734360059548635
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020203999520163052
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015536100181634538
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014732499403180555
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014243700570659712
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=7.495099634979852e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00017691700486466289
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0003985959992860444
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.717699943692423e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.748000229592435e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003968451001128415
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003076731001783628
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=9.348001185571775e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001622300005692523
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.002988143991387915
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003055179950024467
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015324100604630075
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0027481260076456238
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.798499816795811e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006229889986570925
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035041399678448215
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001853159956226591
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001838730022427626
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024746000053710304
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001475650024076458
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001459219965909142
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000195075997908134
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015325699496315792
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00040715599971008487
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.4444001610390842e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003118599961453583
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019957200129283592
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00027001500347978435
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020701099856523797
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021032499716966413
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013525900067179464
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004365900058473926
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007758529936836567
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014961900160415098
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001554280024720356
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001990289965760894
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00042491499698371626
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00038390200279536657
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001709300049697049
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035198799741920084
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000231715999689186
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004037630023958627
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023134499497245997
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002103160004480742
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027045201204600744
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022192699543666095
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002024099994741846
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018851600179914385
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00027797399525297806
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002057369965768885
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017509899771539494
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016801800302346237
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001340469971182756
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001463980006519705
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0035700940061360598
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001478870071878191
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014370800272445194
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000372420996427536
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002093600014632102
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002429120104352478
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020865000260528177
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001644320072955452
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020195399702060968
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00045169399891165085
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022703899230691604
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002247320044261869
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010796599963214248
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034367300395388156
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003635969951574225
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002534320046834182
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023189300554804504
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002838100008375477
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023653299649595283
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021943000319879502
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=5.1017999794567004e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0019527340009517502
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.002112569003656972
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-ACU.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-ACU.optimize] done with 29 nodes in 0.064
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows the kernel FusedMatMul[com.microsoft], which implements a kernel equivalent to Gemm but working on tensors of any rank, not only 2D ones (a small numeric check of this kernel appears after the next output). How does it work on the model which keeps the submodules exported as local functions? The optimizer optimizes every local function independently. We reduce the verbosity…

onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to work on this simple case as well, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
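
As a sanity check of the fused kernel mentioned earlier, here is a minimal, self-contained sketch (the input names A, B, Y are arbitrary, and it assumes a recent onnxruntime with the contrib op available) building a one-node model calling com.microsoft.FusedMatMul with the same attributes as in the dump, and verifying it computes 0.25 * A @ B^T on 3D inputs:

import numpy as np
import onnx
import onnx.helper as oh
from onnxruntime import InferenceSession

# One node replicating the fused kernel seen in the dump above:
# alpha=0.25 and transB=1 means 0.25 * A @ transpose(B, last two axes).
node = oh.make_node(
    "FusedMatMul", ["A", "B"], ["Y"], domain="com.microsoft",
    alpha=0.25, transA=0, transB=1,
)
graph = oh.make_graph(
    [node],
    "fused_matmul_check",
    [
        oh.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 30, 16]),
        oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [1, 30, 16]),
    ],
    [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 30, 30])],
)
model = oh.make_model(
    graph,
    opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
)
model.ir_version = 10  # stay compatible with most onnxruntime releases

sess = InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
A = np.random.rand(1, 30, 16).astype(np.float32)
B = np.random.rand(1, 30, 16).astype(np.float32)
got = sess.run(None, {"A": A, "B": B})[0]
# Same computation with numpy: transB=1 swaps the two last axes of B.
assert np.allclose(got, 0.25 * A @ B.transpose(0, 2, 1), atol=1e-5)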

Optimizations for CUDA

The optimizer may behave differently knowing the model is running on CUDA. It may use different kernels and apply different optimizations if needed. That may not be the best place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and cannot decide which provider would be more appropriate for a given kernel. This could even be decided at runtime.
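
The runtime side of that decision already exists: when creating an inference session, onnxruntime accepts an ordered list of execution providers and falls back to the next one for kernels missing on the preferred device. A minimal sketch, assuming a hypothetical file llm.onnx:

from onnxruntime import InferenceSession

# Providers are tried in order; kernels not implemented on CUDA
# silently fall back to the CPU provider.
sess = InferenceSession(
    "llm.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

The export below only asks the offline optimizer to assume a CUDA target: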

onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-MNW.optimize] start with 73 nodes
[GraphBuilder-MNW.optimize] #patterns=102
[GraphBuilder-MNW.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-MNW.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-MNW.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-MNW.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-MNW.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-MNW.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-MNW.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-MNW.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-MNW.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-MNW.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-MNW.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-MNW.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-MNW.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-MNW.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-MNW.optimize] start with 53 nodes, 28 initializers, 102 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   1/102 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   2/102 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   3/102 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   4/102 - P0 - CastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   5/102 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   6/102 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   7/102 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   8/102 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern   9/102 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  10/102 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  11/102 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  12/102 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  13/102 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  14/102 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  15/102 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  16/102 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  17/102 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  18/102 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  19/102 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  20/102 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  21/102 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  22/102 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  23/102 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  24/102 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  25/102 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  26/102 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  27/102 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  28/102 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  29/102 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  30/102 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  31/102 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  32/102 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  33/102 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  34/102 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  35/102 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  36/102 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  37/102 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  38/102 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  39/102 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  40/102 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  41/102 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  42/102 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  43/102 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  44/102 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  45/102 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  46/102 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  47/102 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  48/102 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  49/102 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  50/102 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  51/102 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  52/102 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  53/102 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  54/102 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  55/102 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  56/102 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  57/102 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  58/102 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  59/102 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  60/102 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  61/102 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  62/102 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  63/102 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  64/102 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  65/102 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  66/102 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  67/102 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  68/102 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  69/102 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  70/102 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  71/102 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  72/102 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  73/102 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  74/102 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  75/102 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  76/102 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  77/102 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  78/102 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  79/102 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  80/102 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  81/102 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  82/102 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  83/102 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  84/102 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  85/102 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  86/102 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  87/102 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  88/102 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  89/102 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  90/102 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  91/102 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  92/102 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  93/102 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  94/102 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  95/102 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  96/102 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  97/102 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  98/102 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern  99/102 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern 100/102 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern 101/102 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-MNW.optimize] use pattern 102/102 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-MNW.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-MNW.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-MNW.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilder-MNW.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-MNW.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-MNW.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-MNW.optimize] increase priority to 1
[GraphBuilderPatternOptimization-MNW.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-MNW.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-MNW.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-MNW.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-MNW.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-MNW.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-MNW.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-MNW.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-MNW.optimize] increase priority to 2
[GraphBuilderPatternOptimization-MNW.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-MNW.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=GeluOrtPattern:0.000
[GraphBuilder-MNW.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-MNW.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-MNW.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-MNW.optimize] increase priority to 3
[GraphBuilderPatternOptimization-MNW.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-MNW.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-MNW.optimize] done after 8 iterations with 29 nodes in 0.042
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00016328200217685662
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.00047822000124142505
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00023145399609347805
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.00045434599815052934
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00017654499970376492
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0010946979964501224
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00013521099754143506
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.001169749997643521
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0019305630194139667
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0007083319978846703
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002275240076414775
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00018863400691770948
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001218450051965192
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031873700572759844
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003499200029182248
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00018547300351201557
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015791699479450472
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00032358399766962975
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002001369939534925
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012561300536617637
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021212500359979458
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001941499976965133
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000229769000725355
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001786030043149367
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001349199992546346
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001189299946418032
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=5.9319001593394205e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00017198299974552356
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00010795799971674569
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011969699698965997
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00017331599519820884
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011628199717961252
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012374400466796942
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002522550021240022
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016782300372142345
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001566600039950572
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002141069999197498
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017531200137455016
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=5.47620038560126e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0001260919962078333
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00031501900230068713
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.688600165536627e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.38929998583626e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0018943459981528576
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0024612349916424137
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=6.03299486101605e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012359100219327956
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.002273370006150799
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00026916200295090675
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014736500088474713
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002310866995685501
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.314900045050308e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000530122997588478
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003056379973713774
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016327000412275083
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014574200031347573
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021547900178120472
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012218900155858137
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011870099842781201
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001674599916441366
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011974699737038463
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003544819956005085
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.9148999854223803e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002711950037337374
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023376999524771236
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028334099260973744
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017636999473324977
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001679860033618752
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011106900274171494
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003943030060327146
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0006762810044165235
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018846100283553824
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012490400331444107
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00016698700346751139
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003718980005942285
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003586149941838812
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014622800154029392
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030972699823905714
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020080499598407187
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000330789993313374
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019694200091180392
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018088100114255212
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025431899848626927
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017006999769364484
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017395400209352374
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001540119992569089
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002338640006200876
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00014351000572787598
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001363239971396979
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011566499961190857
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012057999992975965
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012339699969743378
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003189224993548123
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012112699550925754
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012160499682067893
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002801240007102024
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001843350000854116
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018335199638386257
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00016874099674168974
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013352900714380667
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00016670000695739873
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003891580090567004
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018675999672268517
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018497000201023184
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.95550006639678e-05
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030761999732931145
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028330699569778517
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019186899589840323
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019036600133404136
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025403100153198466
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017220400332007557
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00016572999811614864
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=4.362199979368597e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.001662053993641166
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0018001399985223543
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-MNW.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-MNW.optimize] done with 29 nodes in 0.052
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
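
The optimized graph relies on operators from the com.microsoft domain (SkipLayerNormalization, FusedMatMul), so it can only be executed by onnxruntime. Before aligning the two models node by node, we can check that they still produce the same results. The following is a minimal sketch: it assumes onx, onx_optimized and the input tensor built earlier in this example are still available, and that the Python variable holding the input is named input_ids like the graph input; InferenceSession and max_diff were imported at the top.

# A sketch only: run both models with onnxruntime and measure the discrepancy.
feeds = {"input_ids": input_ids.numpy()}
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
sess_optimized = InferenceSession(
    onx_optimized.SerializeToString(), providers=["CPUExecutionProvider"]
)
expected = sess.run(None, feeds)
got = sess_optimized.run(None, feeds)
# max_diff reports the maximum absolute and relative discrepancies.
print(max_diff(expected, got))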

Comparison between the optimized and non-optimized models

The following tool tries to match the nodes and the inferred shapes of two models. If the models are not too different, the function is able to align them and to show where they diverge. We can use it to see which operators were fused into bigger ones implemented only by onnxruntime.

# Execute both models on the same inputs and align their intermediate results.
res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA int64    1:1                  BAAA                 in |
004 - | INITIA int64    1:1                  AAAA                 in |
005 - | INITIA int64    1:1                  EAAA                 in |
006 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              AAAA                 p_
007 ~ | INITIA float32  2:16x16              AAAA                 p_ | INITIA float32  2:16x16              AAAZ                 p_
008 ~ | INITIA float32  2:16x16              AAAZ                 p_ | INITIA float32  2:16x16              AAAA                 p_
009 ~ | INITIA float32  2:16x16              AAAA                 p_ | INITIA float32  2:30x30              KGSP                 sl
010 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
011 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              AAAA                 p_
012 ~ | INITIA float32  2:16x16              AAAA                 p_ | INITIA float32  2:16x16              AABA                 p_
013 ~ | INITIA float32  2:16x16              AABA                 p_ | INITIA float32  2:16x16              AACB                 p_
014 ~ | INITIA float32  2:16x16              AACB                 p_ | INITIA float32  2:30x30              KGSP                 sl
015 = | INITIA float32  2:32x16              AAAA                 p_ | INITIA float32  2:32x16              AAAA                 p_
016 = | INITIA float32  2:16x128             VYBV                 p_ | INITIA float32  2:16x128             VYBV                 p_
017 = | INITIA float32  2:128x16             CABZ                 p_ | INITIA float32  2:128x16             CABZ                 p_
018 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
019 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
020 = | INITIA float32  2:1024x16            ACSE                 em | INITIA float32  2:1024x16            ACSE                 em
021 = | INITIA float32  2:1024x16            VYQK                 em | INITIA float32  2:1024x16            VYQK                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                BAAZ                 de | INITIA float32  1:128                BAAZ                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 - | RESULT int64    1:2                  ABAA Concat          Sl |
027 - | RESULT int64    1:2                  EEAA Concat          Sl |
028 - | RESULT int64    1:2                  AAAA Concat          Sl |
029 = | RESULT float32  3:1x30x16            OIOQ Gather          em | RESULT float32  3:1x30x16            OIOQ Gather          em
030 = | RESULT float32  3:1x30x16            OQQE Gather          em | RESULT float32  3:1x30x16            OQQE Gather          em
031 ~ | RESULT float32  3:1x30x16            CZEU Add             ad | RESULT float32  3:1x30x16            BZDX SkipLayerNormal _o
032 ~ | RESULT float32  3:1x30x16            BZDX LayerNormalizat _o | RESULT float32  3:1x30x1             ZABA SkipLayerNormal un
033 ~ | RESULT float32  3:1x30x16            SQRW MatMul          li | RESULT float32  3:1x30x1             GFFE SkipLayerNormal un
034 ~ | RESULT float32  3:1x30x16            WDXZ MatMul          li | RESULT float32  3:1x30x16            CZEU SkipLayerNormal ad
035 ~ | RESULT float32  3:1x30x16            BEQG MatMul          li | RESULT float32  3:1x30x16            SQRW MatMul          li
036 ~ | RESULT float32  3:1x16x30            GWZT Transpose       tr | RESULT float32  3:1x30x16            WDXZ MatMul          li
037 ~ | RESULT float32  3:1x30x30            WTFV MatMul          ma | RESULT float32  3:1x30x30            TEBF FusedMatMul     _o
038 ~ | RESULT float32  3:1x30x30            TEBF Mul             _o | RESULT float32  3:1x30x16            BEQG MatMul          li
039 - | RESULT float32  2:30x30              KGSP Slice           sl |
040 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
041 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
042 = | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x30            IHHH Softmax         so
043 = | RESULT float32  3:1x30x16            BBZA MatMul          ma | RESULT float32  3:1x30x16            BBZA MatMul          ma
044 = | RESULT float32  3:1x30x16            XAIF MatMul          li | RESULT float32  3:1x30x16            XAIF MatMul          li
045 = | RESULT float32  3:1x30x16            HCAA MatMul          li | RESULT float32  3:1x30x16            HCAA MatMul          li
046 ~ | RESULT float32  3:1x30x16            AXIB MatMul          li | RESULT float32  3:1x30x30            FUYA FusedMatMul     _o
047 ~ | RESULT float32  3:1x16x30            AIYD Transpose       tr | RESULT float32  3:1x30x16            AXIB MatMul          li
048 ~ | RESULT float32  3:1x30x30            VASB MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
049 ~ | RESULT float32  3:1x30x30            FUYA Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
050 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IGHH Softmax         so
051 - | RESULT bool     2:30x30              HLZC Equal           eq |
052 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            IZBB MatMul          ma
053 ~ | RESULT float32  3:1x30x30            IGHH Softmax         so | RESULT float32  3:1x30x32            JAAB Concat          ca
054 ~ | RESULT float32  3:1x30x16            IZBB MatMul          ma | RESULT float32  3:1x30x16            AACB MatMul          _o
055 ~ | RESULT float32  3:1x30x32            JAAB Concat          ca | RESULT float32  3:1x30x16            WVZY Add             li
056 ~ | RESULT float32  3:1x30x16            AACB MatMul          _o | RESULT float32  3:1x30x16            CYDX SkipLayerNormal _o
057 ~ | RESULT float32  3:1x30x16            WVZY Add             li | RESULT float32  3:1x30x1             ZABA SkipLayerNormal un
058 ~ | RESULT float32  3:1x30x16            XUCS Add             ad | RESULT float32  3:1x30x1             HFFE SkipLayerNormal un
059 ~ | RESULT float32  3:1x30x16            CYDX LayerNormalizat _o | RESULT float32  3:1x30x16            XUCS SkipLayerNormal ad
060 = | RESULT float32  3:1x30x128           UZFK MatMul          _o | RESULT float32  3:1x30x128           UZFK MatMul          _o
061 = | RESULT float32  3:1x30x128           CEMP Add             li | RESULT float32  3:1x30x128           CEMP Add             li
062 = | RESULT float32  3:1x30x128           GDME Relu            re | RESULT float32  3:1x30x128           GDME Relu            re
063 = | RESULT float32  3:1x30x16            DFDG MatMul          _o | RESULT float32  3:1x30x16            DFDG MatMul          _o
064 = | RESULT float32  3:1x30x16            FGFI Add             li | RESULT float32  3:1x30x16            FGFI Add             li
065 = | RESULT float32  3:1x30x16            CAHA Add             ou | RESULT float32  3:1x30x16            CAHA Add             ou
066 = | OUTPUT float32  3:1x30x16            CAHA                 ou | OUTPUT float32  3:1x30x16            CAHA                 ou
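
In the alignment above, rows marked = are identical in both models, rows marked ~ are aligned but differ, and rows marked - only exist in the first, non-optimized model. They show what the optimization did: the Slice nodes producing the mask were constant-folded into initializers, each Transpose + MatMul + Mul sequence was replaced by a single FusedMatMul (alpha=0.25 absorbs the scaling by C**-0.5), and each Add + LayerNormalization pair became a SkipLayerNormalization. A quick way to quantify this is to count node types on both sides; a minimal sketch (only the collections import is new):

from collections import Counter

# Compare the distribution of node types before and after optimization.
print(Counter(node.op_type for node in onx.graph.node))
print(Counter(node.op_type for node in onx_optimized.graph.node))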

The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that's a topic for another example.
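
As a rough sketch of what such an export could look like, assuming llm and input_ids are the model and input built earlier, that the forward argument is named x, and that this version of to_onnx accepts a torch.export-style dynamic_shapes argument:

# A sketch only: mark the sequence dimension as dynamic so the exported
# model accepts sequences of any length up to the context size (256 here).
seq = torch.export.Dim("seq_length", min=2, max=256)
onx_dynamic = to_onnx(llm, (input_ids,), dynamic_shapes={"x": {1: seq}})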

Total running time of the script: (0 minutes 1.917 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

to_onnx and a custom operator registered with a function
