to_onnx and submodules from LLMs

Big models are hard to read once converted into ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with PyTorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    """Sum of two learned embedding tables looked up with the same indices.

    Both ``embedding`` and ``pe`` are ``torch.nn.Embedding`` tables of shape
    ``(vocab_size, embedding_dim)``; ``forward`` indexes both of them with
    the tensor it receives and adds the two results.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        table_shape = (vocab_size, embedding_dim)
        self.embedding = torch.nn.Embedding(*table_shape)
        self.pe = torch.nn.Embedding(*table_shape)

    def forward(self, x):
        """Return ``self.embedding(x) + self.pe(x)``."""
        return self.embedding(x) + self.pe(x)


class AttentionBlock(torch.nn.Module):
    """One attention head restricted by a lower-triangular mask.

    The head owns three bias-free ``embedding_dim -> embedding_dim`` linear
    projections (``query``, ``key``, ``value``) and a ``mask`` buffer of
    shape ``(context_size, context_size)`` holding ``tril(ones)``.
    """

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()

        def bias_free_projection():
            return torch.nn.Linear(embedding_dim, embedding_dim, bias=False)

        # Registration order (query, key, value, mask) is kept on purpose.
        self.query = bias_free_projection()
        self.key = bias_free_projection()
        self.value = bias_free_projection()

        lower_triangle = torch.tril(
            torch.ones(context_size, context_size, dtype=torch.float)
        )
        self.register_buffer("mask", lower_triangle)

    def forward(self, x):
        """Apply masked scaled dot-product attention to a 3-D tensor ``x``.

        The scores are scaled by ``x.shape[-1] ** -0.5``; positions where the
        sliced mask equals 0 are set to ``-inf`` before the softmax.
        """
        _, seq_len, width = x.shape

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        scores = torch.matmul(q, k.transpose(-2, -1)) * width**-0.5
        masked = scores.masked_fill(self.mask[:seq_len, :seq_len] == 0, float("-inf"))
        weights = torch.softmax(masked, dim=-1)

        return torch.matmul(weights, v)


class MultiAttentionBlock(torch.nn.Module):
    """Several ``AttentionBlock`` heads followed by a linear projection.

    ``attention`` is a ``ModuleList`` of ``num_heads`` heads; ``linear`` maps
    the concatenated head outputs (``embedding_dim * num_heads`` features)
    back to ``embedding_dim`` features.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        heads = [AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        self.attention = torch.nn.ModuleList(heads)
        self.linear = torch.nn.Linear(embedding_dim * num_heads, embedding_dim)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last axis, project."""
        per_head = [head(x) for head in self.attention]
        merged = torch.cat(per_head, dim=-1)
        return self.linear(merged)


class FeedForward(torch.nn.Module):
    """Two linear layers with a ReLU in between.

    ``linear_1`` maps ``embedding_dim -> ff_dim`` and ``linear_2`` maps
    ``ff_dim -> embedding_dim``.
    """

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features=embedding_dim, out_features=ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(in_features=ff_dim, out_features=embedding_dim)

    def forward(self, x):
        """Return ``linear_2(relu(linear_1(x)))``."""
        hidden = self.relu(self.linear_1(x))
        return self.linear_2(hidden)


class DecoderLayer(torch.nn.Module):
    """Attention sub-layer then feed-forward sub-layer, each with a skip add.

    Each sub-layer receives a ``LayerNorm``-ed copy of its input, and the
    un-normalised input is added back to the sub-layer output.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return ``h + feed_forward(norm_2(h))`` with ``h = x + attention(norm_1(x))``."""
        # First sub-layer: attention on the normalised input, plus the input.
        after_attention = self.attention(self.norm_1(x)) + x

        # Second sub-layer: feed-forward on the normalised result, plus the result.
        return self.feed_forward(self.norm_2(after_attention)) + after_attention


class LLM(torch.nn.Module):
    """Toy model: an ``Embedding`` followed by a single ``DecoderLayer``.

    Defaults: ``vocab_size=1024``, ``embedding_dim=16``, ``num_heads=2``,
    ``context_size=256``, ``ff_dim=128``.
    """

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        """Embed ``input_ids`` and pass the result through the decoder layer."""
        return self.decoder(self.embedding(input_ids))


# Build the model with its default hyper-parameters and run it once in eager
# mode on random token ids; `y` is reused below as the reference output.
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(low=0, high=1024, size=dim).to(torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-4.333163261413574, max=4.399672031402588

First conversion to ONNX

The conversion relies on torch.export.export(), which gives:

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, 0, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, 0, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, 0, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

Then function to_onnx converts it into ONNX.

# Export the eager model with the default options of to_onnx, using the same
# input that produced the reference output `y`, then print a text rendering.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_30' type=int64 shape=(1,) -- array([30])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='init1_s_::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Concat(init7_s1_0, init7_s1_1, axis=0) -> SliceSlicePattern_init7_s1_1_axis
Concat(init7_s1_30, init7_s1_30, axis=0) -> SliceSlicePattern_init7_s1_30_end
Concat(init7_s1_0, init7_s1_0, axis=0) -> SliceSlicePattern_init7_s1_0_start
  Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
    Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1) -> _onx_mul_matmul
      Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
        Softmax(masked_fill, axis=-1) -> softmax
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
        MatMul(softmax, linear_2) -> matmul_1
      MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
    MatMul(linear_3, transpose_1) -> matmul_2
      Mul(matmul_2, init1_s_::RSh1) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_4
  Equal(slice_4, init1_s_2::RSh1) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_1
        MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
                Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there is no discrepancy.

# Run the exported model with onnxruntime on CPU and compare its first
# output with the eager output `y`.
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.333163261413574, max=4.399672031402588
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

# Model exported without export_modules_as_functions; saved with an ".inlined" tag.
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. Function to_onnx is calling the function torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.

/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called (embedding) is not an instance of Embedding but an instance of InterpreterModule; it contains the fx nodes contributing to the submodule and coming from the previous graph.

Now the ONNX graph.

# Same export as before, this time asking to_onnx to keep the submodules
# (export_modules_as_functions=True), then print a text rendering.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

# Run the model exported with export_modules_as_functions=True with
# onnxruntime on CPU and compare its first output with the eager output `y`.
sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-4.333163261413574, max=4.399672031402588
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

# Model exported with export_modules_as_functions=True; saved with a ".module" tag.
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.

plot exporter recipes c modules

Inlining

The ONNX graph can still be inlined after this.

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16__cst2init' type=float32 shape=(16,)             -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s16_2_cst2init' type=float32 shape=(16,)            -- GraphBuilderPatternOptimization.make_initializer.0
init: name='init1_s1__cst2init' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_::RSh1_cst2init' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='init1_s_2::RSh1_cst2init' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilderPatternOptimization.make_initializer.1/Small
init: name='SliceSlicePattern_init7_s1_0_start_cst2init' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end_cst2init' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis_cst2init' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilderPatternOptimization.make_initializer.1/Shape
init: name='bias_cst2init' type=float32 shape=(16,)                   -- GraphBuilderPatternOptimization.make_initializer.0
init: name='bias2_cst2init' type=float32 shape=(16,)                  -- GraphBuilderPatternOptimization.make_initializer.0
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
    LayerNormalization(embedding, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
      MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
      Mul(matmul, init1_s_::RSh1_cst2init) -> _onx_mul_matmul
MatMul(norm_1, weight::T103) -> value
Slice(mask, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_2
  Equal(slice_2, init1_s_2::RSh1_cst2init) -> eq
    Where(eq, init1_s1__cst2init, _onx_mul_matmul) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
    Mul(matmul2, init1_s_::RSh1_cst2init) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Slice(mask2, SliceSlicePattern_init7_s1_0_start_cst2init, SliceSlicePattern_init7_s1_30_end_cst2init, SliceSlicePattern_init7_s1_1_axis_cst2init) -> slice_22
  Equal(slice_22, init1_s_2::RSh1_cst2init) -> eq2
    Where(eq2, init1_s1__cst2init, _onx_mul_matmul2) -> masked_fill2
      Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias_cst2init) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16__cst2init, init1_s16_2_cst2init, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
          Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, weight::T1023) -> _onx_matmul_relu
                Add(_onx_matmul_relu, bias2_cst2init) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and less efficient. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

# Optimization settings: the default patterns plus the onnxruntime ones,
# with constant folding enabled and verbosity level 2.
ort_fusion_options = OptimizationOptions(
    patterns="default+onnxruntime",
    constant_folding=True,
    verbose=2,
)

# Export the model again, this time with the settings above.
onx_optimized = to_onnx(llm, (input_ids,), options=ort_fusion_options)

# Print a text rendering of the optimized model.
print(pretty_onnx(onx_optimized))
[GraphBuilder-HLO.optimize] start with 73 nodes
[GraphBuilder-HLO.optimize] #patterns=110
[GraphBuilder-HLO.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-HLO.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-HLO.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-HLO.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-HLO.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-HLO.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-HLO.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-HLO.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-HLO.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-HLO.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-HLO.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-HLO.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern   9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  46/110 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  47/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  48/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  49/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  50/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  51/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  52/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  53/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  54/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  55/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  56/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  57/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  58/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  59/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  60/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  61/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  62/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  63/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  64/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  65/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  66/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  67/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  68/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  69/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  70/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  71/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  72/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  73/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  74/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  75/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  76/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  77/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  78/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  79/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  80/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  81/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  82/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  83/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  84/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  85/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  86/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  87/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  88/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  89/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  90/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  91/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  92/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  93/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  94/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  95/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  96/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  97/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  98/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern  99/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 100/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 101/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 102/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 103/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 104/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 105/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 106/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 107/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 108/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 109/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-HLO.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-HLO.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-HLO.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-HLO.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilder-HLO.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-HLO.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-HLO.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-HLO.optimize] increase priority to 1
[GraphBuilderPatternOptimization-HLO.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-HLO.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-HLO.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-HLO.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-HLO.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.005 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization-HLO.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-HLO.optimize] increase priority to 2
[GraphBuilderPatternOptimization-HLO.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-HLO.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-HLO.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-HLO.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-HLO.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-HLO.optimize] increase priority to 3
[GraphBuilderPatternOptimization-HLO.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-HLO.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-HLO.optimize] done after 8 iterations with 29 nodes in 0.054
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.0001723969980957918
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.000673380996886408
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00023979999969014898
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004525479998847004
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0003544610008248128
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0013156780041754246
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011354800153640099
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=9.716000931803137e-06
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0008246120050898753
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007010170011199079
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007421599984809291
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0007541799932369031
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.0009021700025186874
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002740159943641629
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023372700525214896
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017684999693301506
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017152099462691694
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004436009985511191
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00023066400171956047
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018826699670171365
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004079410027770791
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002456150032230653
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015757298751850612
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002946699969470501
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002486100020178128
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002692319903871976
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022198700025910512
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018808200547937304
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015097799769137055
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001672569997026585
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.977200195658952e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00022119199638837017
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001333680011157412
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001511329974164255
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00021671600552508608
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014730300244991668
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001703389971225988
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003405250026844442
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026911100576398894
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000207890996534843
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018999300664290786
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018627900135470554
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.528699254384264e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0002006879985856358
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.0004500110007938929
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.68450008763466e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.222200267482549e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00020293600027798675
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00264904199866578
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0029831009996996727
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=8.497001545038074e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016557799972360954
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0021669380039384123
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003258099968661554
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020132799909333698
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0025238369998987764
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.261599941761233e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0007021879937383346
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019630399765446782
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016829699961817823
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00041063000026042573
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002414660011709202
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021530300000449643
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026854799943976104
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001649089972488582
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00026000799698522314
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024280999423353933
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017760900300345384
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005246460059424862
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=3.312199987703934e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003510160058795009
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023547700038761832
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003237529963371344
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022316700051305816
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031471100010094233
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015838200124562718
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0005291520028549712
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0009448159980820492
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018283100507687777
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018946099589811638
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022426699797506444
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005329160012479406
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004506580007728189
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019979700664407574
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004055519966641441
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002698799944482744
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004581969988066703
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002892280062951613
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023894499827292748
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003005779981322121
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024243699954240583
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022982299924478866
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000219395002204692
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031030700120027177
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00019556899496819824
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021693200324079953
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015458200141438283
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016941800276981667
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017050000315066427
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0037985969938745257
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016551199951209128
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016104599490063265
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003632650004874449
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024043599842116237
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000257101004535798
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002237959997728467
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018160500258090906
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021849099721293896
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003226369990443345
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00039116699917940423
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006148589964141138
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032628500048303977
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003146920025756117
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00011871900278492831
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0003131949997623451
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005266170082904864
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00047179899411275983
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00028350500724627636
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00027891799618373625
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003492720024951268
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025233099586330354
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002409049920970574
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002463310011080466
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=4.880599590251222e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0019299200030218344
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0021111540008860175
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-HLO.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-HLO.optimize] done with 29 nodes in 0.064
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows a kernel FusedMatMul[com.microsoft] which implements a kernel equivalent to Gemm but working for any tensors, not only 2D. How does it work on the model which exports the modules as local functions? The optimizer optimizes every local function independently. We reduce the verbosity…

# Export the model with the "default+onnxruntime" optimization patterns and
# constant folding enabled, asking the exporter to export the submodules as
# functions (export_modules_as_functions=True). No verbose flag is passed
# here, so only the pretty-printed model is displayed.
onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='weight::T10' type=float32 shape=(16, 16)                  -- GraphBuilder.make_local_function/from(weight::T10)
init: name='weight::T102' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T102)
init: name='weight::T103' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T103)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='weight::T104' type=float32 shape=(16, 16)                 -- GraphBuilder.make_local_function/from(weight::T104)
init: name='weight::T1022' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1022)
init: name='weight::T1032' type=float32 shape=(16, 16)                -- GraphBuilder.make_local_function/from(weight::T1032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='weight::T105' type=float32 shape=(32, 16)                 -- GraphBuilder.make_local_function/from(weight::T105)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='weight::T106' type=float32 shape=(16, 128)                -- GraphBuilder.make_local_function/from(weight::T106)
init: name='weight::T1023' type=float32 shape=(128, 16)               -- GraphBuilder.make_local_function/from(weight::T1023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, init1_s_2::RSh12) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, weight::T10) -> query
MatMul(norm_1, weight::T102) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_2, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, weight::T103) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, weight::T104) -> query2
MatMul(norm_1, weight::T1022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul2
MatMul(norm_1, weight::T1032) -> value2
Equal(slice_4, init1_s_2::RSh12) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul2) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, weight::T105) -> _onx_matmul_cat
        Add(_onx_matmul_cat, bias) -> attention
    SkipLayerNormalization[com.microsoft](attention, embedding, init1_s16_3, init1_s16_22, epsilon=0.00) -> norm_2, unused3, unused4, add_1
      MatMul(norm_2, weight::T106) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_1
          Relu(linear_1) -> relu
            MatMul(relu, weight::T1023) -> _onx_matmul_relu
              Add(_onx_matmul_relu, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to be working as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.

Optimizations for CUDA

The optimizer may behave differently knowing the model is running on CUDA. It may use different kernels and do different optimizations if needed. That may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and is not able to decide which provider would be more useful for some kernels. This could even be decided at runtime.

# Export the model with the "default+onnxruntime" optimization patterns and
# constant folding, this time telling the optimizer that the target
# processor is CUDA (processor="CUDA"). verbose=2 is passed to the
# optimization options; the optimizer log is printed before the model.
onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-EWM.optimize] start with 73 nodes
[GraphBuilder-EWM.optimize] #patterns=110
[GraphBuilder-EWM.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-EWM.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-EWM.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-EWM.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-EWM.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-EWM.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-EWM.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-EWM.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-EWM.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-EWM.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-EWM.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-EWM.optimize] start with 53 nodes, 28 initializers, 110 patterns, priorities=[0, 1, 2, 3], max_iter=212
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   1/110 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   2/110 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   3/110 - P0 - CastCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   4/110 - P0 - CastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   5/110 - P0 - ConcatGatherPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   6/110 - P0 - ConcatReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   7/110 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   8/110 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern   9/110 - P0 - FunctionAttentionPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  10/110 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  11/110 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  12/110 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  13/110 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  14/110 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  15/110 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  16/110 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  17/110 - P0 - SameChildrenFromInputPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  18/110 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  19/110 - P0 - ShapeBasedEditDistanceReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  20/110 - P0 - ShapeBasedIdentityPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  21/110 - P0 - ShapeBasedReshapeIsSqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  22/110 - P0 - ShapeBasedSameChildrenPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  23/110 - P0 - ShapeBasedShapeShapeAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  24/110 - P0 - ShapeBasedStaticExpandPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  25/110 - P0 - ShapedBasedReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  26/110 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  27/110 - P0 - SqueezeAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  28/110 - P0 - SqueezeBinaryUnsqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  29/110 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  30/110 - P0 - StaticConcatReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  31/110 - P0 - SwapExpandReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  32/110 - P0 - SwapUnaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  33/110 - P0 - SwapUnsqueezeTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  34/110 - P0 - TransposeGatherPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  35/110 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  36/110 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  37/110 - P0 - UnsqueezeOrSqueezeReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  38/110 - P0 - UnsqueezeReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  39/110 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  40/110 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  41/110 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  42/110 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  43/110 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  44/110 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  45/110 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  46/110 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  47/110 - P1 - ConcatEmptyPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  48/110 - P1 - ConcatTwiceUnaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  49/110 - P1 - ConstantToInitializerPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  50/110 - P1 - ContribRotaryEmbedding3DPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  51/110 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  52/110 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  53/110 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  54/110 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  55/110 - P1 - FunctionCausalMaskMulAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  56/110 - P1 - FunctionCausalMaskPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  57/110 - P1 - FunctionCosSinCachePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  58/110 - P1 - FunctionHalfRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  59/110 - P1 - GathersSplitPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  60/110 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  61/110 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  62/110 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  63/110 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  64/110 - P1 - MissingCosSinPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  65/110 - P1 - MissingRangePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  66/110 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  67/110 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  68/110 - P1 - MultiHeadAttention3DPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  69/110 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  70/110 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  71/110 - P1 - RMSNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  72/110 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  73/110 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  74/110 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  75/110 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  76/110 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  77/110 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  78/110 - P1 - RotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  79/110 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  80/110 - P1 - ShapeBasedConcatExpandPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  81/110 - P1 - ShapeBasedExpandBroadcastMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  82/110 - P1 - ShapeBasedExpandBroadcastPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  83/110 - P1 - ShapeBasedExpandCastWhereSwapPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  84/110 - P1 - ShapeBasedExpandSwapPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  85/110 - P1 - ShapeBasedMatMulToMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  86/110 - P1 - SimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  87/110 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  88/110 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  89/110 - P1 - SkipSimplifiedLayerNormalizationMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  90/110 - P1 - SkipSimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  91/110 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  92/110 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  93/110 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  94/110 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  95/110 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  96/110 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  97/110 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  98/110 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern  99/110 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 100/110 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 101/110 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 102/110 - P2 - ContribRotaryEmbeddingPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 103/110 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 104/110 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 105/110 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 106/110 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 107/110 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 108/110 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 109/110 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-EWM.optimize] use pattern 110/110 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-EWM.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-EWM.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-EWM.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilder-EWM.remove_unused] remove_initializer 1:5/28:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:6/28:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:7/28:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-EWM.remove_unused] remove_initializer 4:8/28:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilderPatternOptimization-EWM.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-EWM.optimize] increase priority to 1
[GraphBuilderPatternOptimization-EWM.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-EWM.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.005 | max_time=IdentityPattern:0.000
[GraphBuilder-EWM.remove_unused] remove_initializer 1:5/26:init7_s1_-1:int64[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:6/26:init1_s1_:float32[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 3:7/26:init1_s1_2:float32[(1,)]
[GraphBuilderPatternOptimization-EWM.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-EWM.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-EWM.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-EWM.optimize] increase priority to 2
[GraphBuilderPatternOptimization-EWM.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-EWM.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilder-EWM.remove_unused] remove_initializer 1:9/23:init1_s_::RSh1:float32[(1,)]
[GraphBuilder-EWM.remove_unused] remove_initializer 2:15/23:init1_s_::RSh12:float32[(1,)]
[GraphBuilderPatternOptimization-EWM.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-EWM.optimize] increase priority to 3
[GraphBuilderPatternOptimization-EWM.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-EWM.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-EWM.optimize] done after 8 iterations with 29 nodes in 0.045
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00017572700016899034
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0004853729988099076
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00023937799414852634
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0004224229996907525
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00016526099716429599
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0010593710067041684
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00010789400039357133
    STAT check_pattern_A10 +0 -0 #it=4 maxmatch=0 i=0 - time=8.25899769552052e-06
    STAT check_pattern_A20 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0006228849924809765
    STAT check_pattern_BD0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005360539980756585
    STAT check_pattern_BI0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005590109976765234
    STAT check_pattern_BU0 +0 -0 #it=8 maxmatch=0 i=0 - time=0.0005760540043411311
    STAT insert_and_remove_nodes +0 -0 #it=0 maxmatch=0 i=0 - time=0.00072660599835217
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00023676199634792283
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00020293999841669574
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013246700109448284
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002626610003062524
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004166899998381268
    STAT match_CastCastPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00020839499484281987
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015502799942623824
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0003525500032992568
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002206230019510258
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013575199409388006
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002454509995004628
    STAT match_ConcatEmptyPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00018761200044536963
    STAT match_ConcatGatherPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00026909499865723774
    STAT match_ConcatReshapePattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002003980043809861
    STAT match_ConcatTwiceUnaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00017381199722876772
    STAT match_ConstantToInitializerPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00013258099716040306
    STAT match_ContribRotaryEmbedding3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013269899864098988
    STAT match_ContribRotaryEmbeddingPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.581199704669416e-05
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001918070083775092
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00011614200775511563
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012782899648300372
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.000193311003386043
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012491100278566591
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013604500054498203
    STAT match_FunctionAttentionPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00025412800459889695
    STAT match_FunctionCausalMaskMulAddPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020083299750695005
    STAT match_FunctionCausalMaskPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013776700143353082
    STAT match_FunctionCosSinCachePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016289999621221796
    STAT match_FunctionHalfRotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001309809995291289
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.145099541754462e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.0001370050013065338
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00032459900103276595
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.578800144372508e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.085700053721666e-05
    STAT match_GathersSplitPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001712220000626985
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0022345819961628877
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0025460659962845966
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=7.695001841057092e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013169399971957318
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0019215859974792693
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0002948989967990201
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001595130015630275
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0023374550000880845
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.9786001656902954e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0005574090027948841
    STAT match_MissingCosSinPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014472100156126544
    STAT match_MissingRangePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013496300016413443
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003049280021514278
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017915699936565943
    STAT match_MultiHeadAttention3DPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017910299720824696
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000239348995819455
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013305900210980326
    STAT match_RMSNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013289100388647057
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017903599655255675
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013284999658935703
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003690820012707263
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.931899896590039e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028197900246595964
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021704999744542874
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00025234700297005475
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019502999930409715
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001841830016928725
    STAT match_RotaryEmbeddingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012774599963449873
    STAT match_SameChildrenFromInputPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0004596960025082808
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007208260030893143
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013939599739387631
    STAT match_ShapeBasedConcatExpandPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013823499830323271
    STAT match_ShapeBasedEditDistanceReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001885819910967257
    STAT match_ShapeBasedExpandBroadcastMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00042165600825683214
    STAT match_ShapeBasedExpandBroadcastPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034922899794764817
    STAT match_ShapeBasedExpandCastWhereSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016093799786176533
    STAT match_ShapeBasedExpandSwapPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00033534799149492756
    STAT match_ShapeBasedIdentityPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021886500690015964
    STAT match_ShapeBasedMatMulToMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003706440002133604
    STAT match_ShapeBasedReshapeIsSqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00023459699514205568
    STAT match_ShapeBasedSameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020034900080645457
    STAT match_ShapeBasedShapeShapeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.000267596999037778
    STAT match_ShapeBasedStaticExpandPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019412000619922765
    STAT match_ShapedBasedReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019651100228657015
    STAT match_SimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00017006600683089346
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002447639999445528
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.000168531001691008
    STAT match_SkipSimplifiedLayerNormalizationMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014786100291530602
    STAT match_SkipSimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012568600504891947
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013235599908512086
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016139199942699634
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0032920120065682568
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015577500380459242
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001337530011369381
    STAT match_SqueezeAddPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033944000460905954
    STAT match_SqueezeBinaryUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019519499619491398
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020287299776100554
    STAT match_StaticConcatReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00019088300177827477
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015230500139296055
    STAT match_SwapExpandReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001900570023281034
    STAT match_SwapUnaryPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002686159969016444
    STAT match_SwapUnsqueezeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021872000070288777
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003819379890046548
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019452800552244298
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019614600023487583
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=8.107699977699667e-05
    STAT match_TransposeGatherPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022033500135876238
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035078099608654156
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00031080899861990474
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00021419699987745844
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00024124601259245537
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002268880016345065
    STAT match_UnsqueezeOrSqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0001956129926838912
    STAT match_UnsqueezeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002117060030286666
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00018981700122822076
    STAT remove_duplicated_shape +0 -0 #it=8 maxmatch=0 i=0 - time=4.4573997001862153e-05
    STAT remove_identity_nodes +9 -15 #it=8 maxmatch=0 i=0 - time=0.0016913559920794796
    STAT remove_unused +0 -0 #it=8 maxmatch=0 i=0 - time=0.0026500370004214346
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  21 x 1t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 21 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   4 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   3 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-EWM.remove_unused] remove_initializer 1:15/21:init1_s_2::RSh12:float32[(1,)]
[GraphBuilder-EWM.optimize] done with 29 nodes in 0.052
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='p_decoder_attention_attention_0_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='p_decoder_attention_attention_0_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='p_decoder_attention_attention_0_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_attention_1_query_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='p_decoder_attention_attention_1_key_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='p_decoder_attention_attention_1_value_weight::T10' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='p_decoder_attention_linear_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='p_decoder_feed_forward_linear_1_weight::T10' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='p_decoder_feed_forward_linear_2_weight::T10' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, init1_s_2::RSh1) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add, unused, unused2, add
    MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_query_weight::T10) -> linear
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_key_weight::T10) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul
  Where(eq, init1_s1_3, _onx_mul_matmul) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add, p_decoder_attention_attention_0_value_weight::T10) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_query_weight::T10) -> linear_3
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_key_weight::T10) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_2
MatMul(_onx_div_sub_add, p_decoder_attention_attention_1_value_weight::T10) -> linear_5
Equal(slice_4, init1_s_2::RSh1) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_2) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, p_decoder_attention_linear_weight::T10) -> _onx_matmul_cat
        Add(_onx_matmul_cat, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_1, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_1, p_decoder_feed_forward_linear_1_weight::T10) -> _onx_matmul_layer_norm_1
        Add(_onx_matmul_layer_norm_1, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, p_decoder_feed_forward_linear_2_weight::T10) -> _onx_matmul_relu
              Add(_onx_matmul_relu, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Comparison between the optimized and the non-optimized models

The following tool tries to match the nodes and the shape inference results of two models. If the models are not too different, the function is able to find the differences between them. We can use it to see which operators were fused into bigger ones that are only implemented by onnxruntime.

res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 66 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 66 results (first model)
[compare_onnx_execution] got 57 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 66 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA int64    1:1                  BAAA                 in |
004 - | INITIA int64    1:1                  AAAA                 in |
005 - | INITIA int64    1:1                  EAAA                 in |
006 ~ | INITIA float32  1:1                  ?AAA                 in | INITIA float32  2:16x16              AAAB                 p_
007 ~ | INITIA float32  2:16x16              AAAB                 p_ | INITIA float32  2:16x16              BAAB                 p_
008 ~ | INITIA float32  2:16x16              BAAB                 p_ | INITIA float32  2:16x16              AAAA                 p_
009 ~ | INITIA float32  2:16x16              AAAA                 p_ | INITIA float32  2:30x30              KGSP                 sl
010 = | INITIA float32  1:1                  AAAA                 in | INITIA float32  1:1                  AAAA                 in
011 ~ | INITIA float32  1:1                  AAAA                 in | INITIA float32  2:16x16              BBAA                 p_
012 ~ | INITIA float32  2:16x16              BBAA                 p_ | INITIA float32  2:16x16              AAZA                 p_
013 ~ | INITIA float32  2:16x16              AAZA                 p_ | INITIA float32  2:16x16              ZAAA                 p_
014 ~ | INITIA float32  2:16x16              ZAAA                 p_ | INITIA float32  2:30x30              KGSP                 sl
015 = | INITIA float32  2:32x16              AZBC                 p_ | INITIA float32  2:32x16              AZBC                 p_
016 = | INITIA float32  2:16x128             CZEW                 p_ | INITIA float32  2:16x128             CZEW                 p_
017 = | INITIA float32  2:128x16             AAZA                 p_ | INITIA float32  2:128x16             AAZA                 p_
018 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
019 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
020 = | INITIA float32  2:1024x16            PHTW                 em | INITIA float32  2:1024x16            PHTW                 em
021 = | INITIA float32  2:1024x16            ILTN                 em | INITIA float32  2:1024x16            ILTN                 em
022 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
023 = | INITIA float32  1:128                AAZA                 de | INITIA float32  1:128                AAZA                 de
024 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
025 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
026 - | RESULT int64    1:2                  ABAA Concat          Sl |
027 - | RESULT int64    1:2                  EEAA Concat          Sl |
028 - | RESULT int64    1:2                  AAAA Concat          Sl |
029 = | RESULT float32  3:1x30x16            VPRM Gather          em | RESULT float32  3:1x30x16            VPRM Gather          em
030 = | RESULT float32  3:1x30x16            ILJE Gather          em | RESULT float32  3:1x30x16            ILJE Gather          em
031 ~ | RESULT float32  3:1x30x16            DAAQ Add             ad | RESULT float32  3:1x30x16            ZBXD SkipLayerNormal _o
032 ~ | RESULT float32  3:1x30x16            ZBXD LayerNormalizat _o | RESULT float32  3:1x30x1             ZAZA SkipLayerNormal un
033 ~ | RESULT float32  3:1x30x16            OTTE MatMul          li | RESULT float32  3:1x30x1             GFGE SkipLayerNormal un
034 ~ | RESULT float32  3:1x30x16            VDVA MatMul          li | RESULT float32  3:1x30x16            DAAQ SkipLayerNormal ad
035 ~ | RESULT float32  3:1x30x16            VVDX MatMul          li | RESULT float32  3:1x30x16            OTTE MatMul          li
036 ~ | RESULT float32  3:1x16x30            BBOA Transpose       tr | RESULT float32  3:1x30x16            VDVA MatMul          li
037 ~ | RESULT float32  3:1x30x30            XIBS MatMul          ma | RESULT float32  3:1x30x30            AWAS FusedMatMul     _o
038 ~ | RESULT float32  3:1x30x30            AWAS Mul             _o | RESULT float32  3:1x30x16            VVDX MatMul          li
039 - | RESULT float32  2:30x30              KGSP Slice           sl |
040 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
041 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
042 = | RESULT float32  3:1x30x30            HHHH Softmax         so | RESULT float32  3:1x30x30            HHHH Softmax         so
043 = | RESULT float32  3:1x30x16            UXWX MatMul          ma | RESULT float32  3:1x30x16            UXWX MatMul          ma
044 = | RESULT float32  3:1x30x16            GXXA MatMul          li | RESULT float32  3:1x30x16            GXXA MatMul          li
045 = | RESULT float32  3:1x30x16            IACC MatMul          li | RESULT float32  3:1x30x16            IACC MatMul          li
046 ~ | RESULT float32  3:1x30x16            XZZX MatMul          li | RESULT float32  3:1x30x30            HUEC FusedMatMul     _o
047 ~ | RESULT float32  3:1x16x30            BDYK Transpose       tr | RESULT float32  3:1x30x16            XZZX MatMul          li
048 ~ | RESULT float32  3:1x30x30            DZSJ MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
049 ~ | RESULT float32  3:1x30x30            HUEC Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
050 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IHHH Softmax         so
051 - | RESULT bool     2:30x30              HLZC Equal           eq |
052 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            AZZX MatMul          ma
053 ~ | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x32            VVVU Concat          ca
054 ~ | RESULT float32  3:1x30x16            AZZX MatMul          ma | RESULT float32  3:1x30x16            ZWWX MatMul          _o
055 ~ | RESULT float32  3:1x30x32            VVVU Concat          ca | RESULT float32  3:1x30x16            AWXX Add             li
056 ~ | RESULT float32  3:1x30x16            ZWWX MatMul          _o | RESULT float32  3:1x30x16            ZBXD SkipLayerNormal _o
057 ~ | RESULT float32  3:1x30x16            AWXX Add             li | RESULT float32  3:1x30x1             ZAYA SkipLayerNormal un
058 ~ | RESULT float32  3:1x30x16            DXWN Add             ad | RESULT float32  3:1x30x1             GFGE SkipLayerNormal un
059 ~ | RESULT float32  3:1x30x16            ZBXD LayerNormalizat _o | RESULT float32  3:1x30x16            DXWN SkipLayerNormal ad
060 = | RESULT float32  3:1x30x128           GCKF MatMul          _o | RESULT float32  3:1x30x128           GCKF MatMul          _o
061 = | RESULT float32  3:1x30x128           BWFY Add             li | RESULT float32  3:1x30x128           BWFY Add             li
062 = | RESULT float32  3:1x30x128           AXWJ Relu            re | RESULT float32  3:1x30x128           AXWJ Relu            re
063 = | RESULT float32  3:1x30x16            BCAY MatMul          _o | RESULT float32  3:1x30x16            BCAY MatMul          _o
064 = | RESULT float32  3:1x30x16            ZAYV Add             li | RESULT float32  3:1x30x16            ZAYV Add             li
065 = | RESULT float32  3:1x30x16            BWTH Add             ou | RESULT float32  3:1x30x16            BWTH Add             ou
066 = | OUTPUT float32  3:1x30x16            BWTH                 ou | OUTPUT float32  3:1x30x16            BWTH                 ou

The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that’s a topic for another example.

Total running time of the script: (0 minutes 2.434 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

Export Phi-3.5-mini-instruct piece by piece

to_onnx and a custom operator inplace

to_onnx and a custom operator inplace

to_onnx and padding one dimension to a multiple of a constant

to_onnx and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

Check the exporter on a dummy from HuggingFace

to_onnx and a custom operator registered with a function

to_onnx and a custom operator registered with a function

Gallery generated by Sphinx-Gallery