to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        word_emb = self.embedding(x)
        word_pe = self.pe(x)
        return word_emb + word_pe


class AttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        ones = torch.ones(size=[context_size, context_size], dtype=torch.float)
        self.register_buffer(name="mask", tensor=torch.tril(input=ones))

    def forward(self, x):
        _B, T, C = x.size()

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        qk = query @ key.transpose(-2, -1) * C**-0.5
        attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention = torch.nn.functional.softmax(input=attention, dim=-1)

        out = attention @ value
        return out


class MultiAttentionBlock(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        self.attention = torch.nn.ModuleList(
            modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim * num_heads, out_features=embedding_dim
        )

    def forward(self, x):
        out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
        x = self.linear(out)
        return x


class FeedForward(torch.nn.Module):

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class DecoderLayer(torch.nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x):
        x_norm = self.norm_1(x)
        attention = self.attention(x_norm)
        attention = attention + x

        attention_norm = self.norm_2(attention)
        ff = self.feed_forward(attention_norm)
        ff = ff + attention

        return ff


class LLM(torch.nn.Module):

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        y = self.decoder(x)
        return y


llm = LLM()
dim = (1, 30)
input_ids = torch.randint(0, 1024, dim).to(torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.1491379737854, max=4.618813514709473

First conversion to ONNX

The conversion relies on torch.export.export().
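
The exporting call itself does not appear in this section; a minimal sketch, assuming the llm and input_ids defined above, could be:

ep = torch.export.export(llm, (input_ids,))
print(ep.graph)

which gives: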

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

Then the function to_onnx converts it into ONNX.

onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='SliceSlicePattern_init7_s1_0_start' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilder.constant_folding.from/fold(init7_s1_0)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilder.constant_folding.from/fold(init7_s1_30)##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
    Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
  MatMul(linear_3, transpose_1) -> matmul_2
    Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there is no discrepancy.

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.1491379737854, max=4.618813514709473
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. The function to_onnx calls torch.export.unflatten.unflatten() under the hood.
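
A minimal sketch of that underlying step (not part of the original code) could be:

ep = torch.export.export(llm, (input_ids,))
unflattened = torch.export.unflatten(ep)
print(unflattened.graph)

The fx graph looks like the following.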

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy, but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule, and it contains the fx nodes contributing to the submodule, coming from the previous graph.
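
Continuing the sketch above, this can be checked directly (a hypothetical verification, not in the original code):

emb = unflattened.get_submodule("embedding")
print(type(emb))                   # expected: <class 'torch.export.unflatten.InterpreterModule'>
print(isinstance(emb, Embedding))  # expected: False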

Now the ONNX graph.

onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[0.0547443...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0113851...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = dict(input_ids=input_ids.numpy())
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.1491379737854, max=4.618813514709473
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.
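
The figure below was presumably rendered with plot_dot, imported at the top of this example; a sketch:

plot_dot(onx_module)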

(figure: graph of the ONNX model with submodules)

Inlining

The ONNX graph can still be inlined after this.
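
The inlining call is not shown here; a sketch using onnx.inliner.inline_local_functions, imported at the top:

onx_inlined = inline_local_functions(onx_module)
print(pretty_onnx(onx_inlined))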

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, weight::T10) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> init1_s_::RSh1
Constant(value=[0.0]) -> init1_s_2::RSh1
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, init1_s_2::RSh1) -> eq
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
  Where(eq, init1_s1_, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> init1_s_::RSh12
Constant(value=[0.0]) -> init1_s_2::RSh12
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, init1_s_2::RSh12) -> eq2
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, init1_s_::RSh12) -> _onx_mul_matmul2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, weight::T1032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
Constant(value=[0.0547443...) -> bias
  Add(_onx_matmul_cat, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
      Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, weight::T1023) -> _onx_matmul_relu
Constant(value=[0.0113851...) -> bias2
  Add(_onx_matmul_relu, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True, verbose=2),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-CQY.optimize] start with 73 nodes
[GraphBuilder-CQY.optimize] #patterns=92
[GraphBuilder-CQY.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-CQY.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-CQY.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-CQY.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-CQY.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-CQY.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-CQY.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-CQY.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-CQY.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-CQY.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-CQY.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-CQY.optimize] start with 53 nodes, 28 initializers, 92 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   1/92 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   2/92 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   3/92 - P0 - CastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   4/92 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   5/92 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   6/92 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   7/92 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   8/92 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern   9/92 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  10/92 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  11/92 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  12/92 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  13/92 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  14/92 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  15/92 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  16/92 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  17/92 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  18/92 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  19/92 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  20/92 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  21/92 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  22/92 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  23/92 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  24/92 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  25/92 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  26/92 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  27/92 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  28/92 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  29/92 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  30/92 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  31/92 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  32/92 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  33/92 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  34/92 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  35/92 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  36/92 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  37/92 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  38/92 - P1 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  39/92 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  40/92 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  41/92 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  42/92 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  43/92 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  44/92 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  45/92 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  46/92 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  47/92 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  48/92 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  49/92 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  50/92 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  51/92 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  52/92 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  53/92 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  54/92 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  55/92 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  56/92 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  57/92 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  58/92 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  59/92 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  60/92 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  61/92 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  62/92 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  63/92 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  64/92 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  65/92 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  66/92 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  67/92 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  68/92 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  69/92 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  70/92 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  71/92 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  72/92 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  73/92 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  74/92 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  75/92 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  76/92 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  77/92 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  78/92 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  79/92 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  80/92 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  81/92 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  82/92 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  83/92 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  84/92 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  85/92 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  86/92 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  87/92 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  88/92 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  89/92 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  90/92 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  91/92 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-CQY.optimize] use pattern  92/92 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-CQY.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-CQY.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.005 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilder-CQY.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-CQY.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-CQY.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-CQY.optimize] increase priority to 1
[GraphBuilderPatternOptimization-CQY.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-CQY.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilder-CQY.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-CQY.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-CQY.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-CQY.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-CQY.optimize] increase priority to 2
[GraphBuilderPatternOptimization-CQY.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-CQY.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilder-CQY.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-CQY.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-CQY.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-CQY.optimize] increase priority to 3
[GraphBuilderPatternOptimization-CQY.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-CQY.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-CQY.optimize] done after 8 iterations with 29 nodes in 0.035
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00014969099902373273
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0004874780006502988
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00026582700229482725
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004067619993293192
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00016573200082348194
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0009638310002628714
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00012380699990899302
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.001013454999338137
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008280439978989307
    STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.132400004484225e-05
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00021630700030073058
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00018701699809753336
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012142100058554206
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001161640029749833
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00034005799898295663
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014117699720372912
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003156380025757244
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.000210778001928702
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.000125485999888042
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020566500097629614
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00019073399926128332
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023604700254509225
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001860960019257618
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014363699847308453
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001273160014534369
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001797380009520566
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011473199992906302
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011958300274272915
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00019102400074189063
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011557599827938247
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001246270003321115
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001639339989196742
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011912799891433679
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012439199235814158
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011644099868135527
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=5.650899947795551e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00012853699990955647
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00029809500119881704
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.78490001114551e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.3609001042786986e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.001709799000309431
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0020909830018354114
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=6.7550026869867e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012081099885108415
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0017317369984084507
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00025622099929023534
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014180600192048587
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0021210819995758357
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.162699926586356e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005160870005056495
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002892239990615053
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001922980009112507
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001770939979905961
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011977900066995062
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000114618000225164
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016320499889843632
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012286600394872949
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003638580001279479
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.9344999600434676e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026351499946031254
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020138600120844785
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023555299958388787
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017352400027448311
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016819400298118126
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013320899779500905
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005386399971030187
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012667300325119868
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001282800003536977
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017276099970331416
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00043448399810586125
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032721499701438006
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015029099813546054
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030225499904190656
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019854699712595902
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003244519994041184
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020433000099728815
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018761299907055218
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023127600252337288
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001735459973133402
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023400800455419812
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00015971799984981772
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012341500041657127
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001424369984306395
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001272450008400483
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002935329999672831
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001218049965245882
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012243399760336615
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027939600113313645
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001865780013758922
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001748650029185228
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013498599946615286
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035561500408221036
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018729499970504548
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000184643000466167
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.041100020112935e-05
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002750849998847116
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002799320027406793
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001963239992619492
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019400299970584456
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020028600192745216
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017426000340492465
    STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0006548620003741235
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0019251579979027156
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-CQY.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-CQY.optimize] done with 29 nodes in 0.043
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows the kernel FusedMatMul[com.microsoft], which implements an operator equivalent to Gemm but works on tensors of any rank, not only 2D. How does the optimizer behave on the export which keeps the submodules as local functions? It optimizes every local function independently; the next call drops the verbose output. But first, a quick check of what FusedMatMul computes.
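
The sketch below is not part of the original example: it builds a tiny ONNX model containing a single FusedMatMul node and checks it against the equivalent numpy expression alpha * A @ B^T, with shapes mirroring the attention block above.

import numpy as np
from onnx import TensorProto
from onnx.helper import (
    make_graph, make_model, make_node, make_opsetid, make_tensor_value_info
)

# FusedMatMul(A, B, alpha=0.25, transB=1) should compute 0.25 * A @ B^T
# on the last two dimensions, batching over the leading ones.
node = make_node(
    "FusedMatMul", ["A", "B"], ["Y"], domain="com.microsoft",
    alpha=0.25, transA=0, transB=1,
)
graph = make_graph(
    [node], "fused_matmul_check",
    [make_tensor_value_info("A", TensorProto.FLOAT, [1, 30, 16]),
     make_tensor_value_info("B", TensorProto.FLOAT, [1, 30, 16])],
    [make_tensor_value_info("Y", TensorProto.FLOAT, [1, 30, 30])],
)
check = make_model(
    graph,
    opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
)
sess = InferenceSession(check.SerializeToString(), providers=["CPUExecutionProvider"])
A = np.random.rand(1, 30, 16).astype(np.float32)
B = np.random.rand(1, 30, 16).astype(np.float32)
got = sess.run(None, {"A": A, "B": B})[0]
# Compare against the plain numpy formulation of the same computation.
assert np.allclose(got, 0.25 * A @ B.transpose(0, 2, 1), atol=1e-5)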

onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to work on this simple case as well, even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.
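
One way to substantiate that claim, and to serve runtimes that cannot consume local functions, is to inline the functions back into a flat graph with onnx.inliner (imported at the top) and compare the result against the eager model. A minimal sketch, assuming llm, input_ids, and onx_module_optimized from the previous steps:

# Flatten the local functions into the main graph, then check the
# numerical agreement with the original torch model.
inlined = inline_local_functions(onx_module_optimized)
sess = InferenceSession(inlined.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, {"input_ids": input_ids.numpy()})[0]
with torch.no_grad():
    expected = llm(input_ids)
print(max_diff(expected, got))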

Optimizations for CUDA

The optimizer may behave differently when it knows the model will run on CUDA: it may pick different kernels and apply different optimizations. This may not be the best place to make that decision, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization is not aware of that and cannot decide which provider would be best for a given kernel; this could even be decided at runtime (a runtime-side sketch follows the output below).

onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-IMQ.optimize] start with 73 nodes
[GraphBuilder-IMQ.optimize] #patterns=92
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-IMQ.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-IMQ.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-IMQ.optimize] start with 53 nodes, 28 initializers, 92 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   1/92 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   2/92 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   3/92 - P0 - CastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   4/92 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   5/92 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   6/92 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   7/92 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   8/92 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern   9/92 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  10/92 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  11/92 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  12/92 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  13/92 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  14/92 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  15/92 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  16/92 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  17/92 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  18/92 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  19/92 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  20/92 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  21/92 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  22/92 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  23/92 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  24/92 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  25/92 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  26/92 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  27/92 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  28/92 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  29/92 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  30/92 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  31/92 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  32/92 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  33/92 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  34/92 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  35/92 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  36/92 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  37/92 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  38/92 - P1 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  39/92 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  40/92 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  41/92 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  42/92 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  43/92 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  44/92 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  45/92 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  46/92 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  47/92 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  48/92 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  49/92 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  50/92 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  51/92 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  52/92 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  53/92 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  54/92 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  55/92 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  56/92 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  57/92 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  58/92 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  59/92 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  60/92 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  61/92 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  62/92 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  63/92 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  64/92 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  65/92 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  66/92 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  67/92 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  68/92 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  69/92 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  70/92 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  71/92 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  72/92 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  73/92 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  74/92 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  75/92 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  76/92 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  77/92 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  78/92 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  79/92 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  80/92 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  81/92 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  82/92 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  83/92 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  84/92 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  85/92 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  86/92 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  87/92 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  88/92 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  89/92 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  90/92 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  91/92 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] use pattern  92/92 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-IMQ.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.005 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-IMQ.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-IMQ.optimize] increase priority to 1
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-IMQ.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-IMQ.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-IMQ.optimize] increase priority to 2
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-IMQ.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=ShapeBasedExpandBroadcastPattern:0.000
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-IMQ.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-IMQ.optimize] increase priority to 3
[GraphBuilderPatternOptimization-IMQ.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-IMQ.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-IMQ.optimize] done after 8 iterations with 29 nodes in 0.038
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00014954299876990262
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0005257770008029183
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00021109499539306853
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0003963510007451987
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0001277029987249989
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0009879899989755359
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010936199942079838
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.001020333000269602
    STAT check_pattern_B0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.000959601997237769
    STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.802400043350644e-05
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023113700262911152
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0001938599980348954
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014631800149800256
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002773910000541946
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00036972099769627675
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017115199989348184
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003359580023243325
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00021366299915825948
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013018200115766376
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00027450100242276676
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00021505200311366934
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002359770005568862
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001970469984371448
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001500750022387365
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001401189983880613
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00017816000399761833
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011399300092307385
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012389000039547682
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00018147999981010798
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012342700392764527
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013117599883116782
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002151770022464916
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012704699656751473
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013118899914843496
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012282100033189636
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.013999700371642e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0001327500012848759
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0003122279995295685
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.7900000936351717e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.283999962557573e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0018348789999436121
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0024085169989120914
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=7.433000064338557e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012835999586968683
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.002097357997627114
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002549539985921001
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015446699944732245
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0020164580000709975
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.646799945679959e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005335019995982293
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030084299760346767
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017048499830707442
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022499699844047427
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012595100270118564
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00029390700001385994
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016994800353131723
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000127850998978829
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003754720000870293
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.03159997909097e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00027197800045541953
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000184544001967879
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002436279974062927
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019114799943054095
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017875799676403403
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00011973899927397724
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005322800025169272
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013045500236330554
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001320669998676749
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017762900097295642
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039467299939133227
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005254629959381418
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015358899872808252
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003153239958919585
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022683799943479244
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003335820001666434
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023295500068343244
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019475999943097122
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023828400298953056
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018373899729340337
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022436499966715928
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0001598860017111292
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013215099897934124
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000124246000268613
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012736800090351608
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0037673959996027406
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014079099855734967
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001317490005021682
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003830739951808937
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021182200180192012
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001850079988798825
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014682700020784978
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00039244999607035425
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022648400226898957
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024795500030450057
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.21280011930503e-05
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030491500001517124
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003235440017306246
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020645799850171898
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021316400125215296
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022563499987882096
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00017694100097287446
    STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0009510040017630672
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0018855740017897915
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-IMQ.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-IMQ.optimize] done with 29 nodes in 0.045
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]
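
As mentioned above, the final placement decision belongs to the runtime. A minimal sketch, assuming onnxruntime-gpu is installed: listing both providers lets the session fall back to CPU for any kernel it cannot place on CUDA.

# The runtime assigns each kernel to the first provider that supports it.
sess_cuda = InferenceSession(
    onx_cuda_optimized.SerializeToString(),
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print(sess_cuda.get_providers())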

Comparing the optimized and non-optimized models

The following tool tries to match the nodes and inferred shapes of two models. If the models are not too different, the function can align them and report the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime. In the output below, '=' marks aligned rows with identical results, '~' aligned rows with differing results, and '-' or '+' rows present in only one of the two models.

res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 63 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 63 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 63 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              AZAA                 p_
004 ~ | INITIA float32  2:16x16              AZAA                 p_ | INITIA float32  2:16x16              BAAZ                 p_
005 ~ | INITIA float32  2:16x16              BAAZ                 p_ | INITIA float32  2:16x16              YZBZ                 p_
006 ~ | INITIA float32  2:16x16              YZBZ                 p_ | INITIA float32  2:30x30              KGSP                 sl
007 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
008 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              AABA                 p_
009 ~ | INITIA float32  2:16x16              AABA                 p_ | INITIA float32  2:16x16              BAAA                 p_
010 ~ | INITIA float32  2:16x16              BAAA                 p_ | INITIA float32  2:16x16              ZBCA                 p_
011 ~ | INITIA float32  2:16x16              ZBCA                 p_ | INITIA float32  2:30x30              KGSP                 sl
012 = | INITIA float32  2:32x16              AAAA                 p_ | INITIA float32  2:32x16              AAAA                 p_
013 = | INITIA float32  2:16x128             XFBD                 p_ | INITIA float32  2:16x128             XFBD                 p_
014 = | INITIA float32  2:128x16             ACBA                 p_ | INITIA float32  2:128x16             ACBA                 p_
015 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
016 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
017 - | INITIA int64    1:2                  AAAA                 Sl |
018 - | INITIA int64    1:2                  EEAA                 Sl |
019 - | INITIA int64    1:2                  ABAA                 Sl |
020 = | INITIA float32  2:1024x16            RSYN                 em | INITIA float32  2:1024x16            RSYN                 em
021 = | INITIA float32  2:1024x16            QUFT                 em | INITIA float32  2:1024x16            QUFT                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                ZAAA                 de | INITIA float32  1:128                ZAAA                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 = | RESULT float32  3:1x30x16            JFNV Gather          em | RESULT float32  3:1x30x16            JFNV Gather          em
027 = | RESULT float32  3:1x30x16            OEIP Gather          em | RESULT float32  3:1x30x16            OEIP Gather          em
028 ~ | RESULT float32  3:1x30x16            WKWJ Add             ad | RESULT float32  3:1x30x16            ZBAA SkipLayerNormal _o
029 ~ | RESULT float32  3:1x30x16            ZBAA LayerNormalizat _o | RESULT float32  3:1x30x1             ZAAA SkipLayerNormal un
030 ~ | RESULT float32  3:1x30x16            ATFB MatMul          li | RESULT float32  3:1x30x1             GGFE SkipLayerNormal un
031 ~ | RESULT float32  3:1x30x16            ZEWY MatMul          li | RESULT float32  3:1x30x16            WKWJ SkipLayerNormal ad
032 ~ | RESULT float32  3:1x30x16            IAFR MatMul          li | RESULT float32  3:1x30x16            ATFB MatMul          li
033 ~ | RESULT float32  3:1x16x30            ZZZA Transpose       tr | RESULT float32  3:1x30x16            ZEWY MatMul          li
034 ~ | RESULT float32  3:1x30x30            QWMM MatMul          ma | RESULT float32  3:1x30x30            YFXX FusedMatMul     _o
035 ~ | RESULT float32  3:1x30x30            YFXX Mul             _o | RESULT float32  3:1x30x16            IAFR MatMul          li
036 - | RESULT float32  2:30x30              KGSP Slice           sl |
037 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
038 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
039 = | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x30            IHHH Softmax         so
040 = | RESULT float32  3:1x30x16            DGCH MatMul          ma | RESULT float32  3:1x30x16            DGCH MatMul          ma
041 = | RESULT float32  3:1x30x16            CCAF MatMul          li | RESULT float32  3:1x30x16            CCAF MatMul          li
042 = | RESULT float32  3:1x30x16            UCAT MatMul          li | RESULT float32  3:1x30x16            UCAT MatMul          li
043 ~ | RESULT float32  3:1x30x16            EXDP MatMul          li | RESULT float32  3:1x30x30            AXAX FusedMatMul     _o
044 ~ | RESULT float32  3:1x16x30            OCCX Transpose       tr | RESULT float32  3:1x30x16            EXDP MatMul          li
045 ~ | RESULT float32  3:1x30x30            AODM MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
046 ~ | RESULT float32  3:1x30x30            AXAX Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
047 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IHHH Softmax         so
048 - | RESULT bool     2:30x30              HLZC Equal           eq |
049 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            HBAC MatMul          ma
050 ~ | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x32            LHCK Concat          ca
051 ~ | RESULT float32  3:1x30x16            HBAC MatMul          ma | RESULT float32  3:1x30x16            AAAA MatMul          _o
052 ~ | RESULT float32  3:1x30x32            LHCK Concat          ca | RESULT float32  3:1x30x16            WWVV Add             li
053 ~ | RESULT float32  3:1x30x16            AAAA MatMul          _o | RESULT float32  3:1x30x16            ZBAA SkipLayerNormal _o
054 ~ | RESULT float32  3:1x30x16            WWVV Add             li | RESULT float32  3:1x30x1             YZAZ SkipLayerNormal un
055 ~ | RESULT float32  3:1x30x16            SGRE Add             ad | RESULT float32  3:1x30x1             GGFE SkipLayerNormal un
056 ~ | RESULT float32  3:1x30x16            ZBAA LayerNormalizat _o | RESULT float32  3:1x30x16            SGRE SkipLayerNormal ad
057 = | RESULT float32  3:1x30x128           AZSV MatMul          _o | RESULT float32  3:1x30x128           AZSV MatMul          _o
058 = | RESULT float32  3:1x30x128           XXOS Add             li | RESULT float32  3:1x30x128           XXOS Add             li
059 = | RESULT float32  3:1x30x128           RDDF Relu            re | RESULT float32  3:1x30x128           RDDF Relu            re
060 = | RESULT float32  3:1x30x16            EDHJ MatMul          _o | RESULT float32  3:1x30x16            EDHJ MatMul          _o
061 = | RESULT float32  3:1x30x16            HGKL Add             li | RESULT float32  3:1x30x16            HGKL Add             li
062 = | RESULT float32  3:1x30x16            ZMCQ Add             ou | RESULT float32  3:1x30x16            ZMCQ Add             ou
063 = | OUTPUT float32  3:1x30x16            ZMCQ                 ou | OUTPUT float32  3:1x30x16            ZMCQ                 ou
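
The alignment shows one SkipLayerNormalization replacing an Add and a LayerNormalization, and FusedMatMul absorbing a Transpose, a MatMul and a Mul. A quick way to double check which operators the optimized graph ends up with is to count them; this only relies on the standard onnx API:

from collections import Counter

print(Counter(node.op_type for node in onx_optimized.graph.node))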

The conversion should handle dynamic shapes as well, since the input sequence can be of any length. A full treatment is a topic for another example, but the sketch below gives the idea.
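
The following sketch assumes to_onnx accepts a dynamic_shapes argument in the same format as torch.export.export, and it reuses the llm and input_ids names assumed above; treat it as an illustration rather than tested code.

seq = torch.export.Dim("seq", max=256)  # the mask buffer caps the context at 256
onx_dyn = to_onnx(
    llm,
    (input_ids,),
    dynamic_shapes={"input_ids": {1: seq}},  # axis 1 is the sequence length
)
print(pretty_onnx(onx_dyn))  # the sequence axis should now be symbolic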

Total running time of the script: (0 minutes 1.560 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

to_onnx and a model with a test

Gallery generated by Sphinx-Gallery