to_onnx and submodules from LLMs

Big models are hard to read once converted to ONNX. Let’s see how to improve their readability. The code is inspired by LLM from scratch with Pytorch.

A simple LLM

All comments were removed from the code to make it less verbose. A few fixes were applied to the original code.

import onnx
from onnx.inliner import inline_local_functions
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_array_api.reference import compare_onnx_execution
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
import torch
from onnxruntime import InferenceSession
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions


class Embedding(torch.nn.Module):
    """Sum of two learned lookup tables queried with the same indices.

    Both tables have ``vocab_size`` rows of width ``embedding_dim``.

    NOTE(review): ``pe`` is looked up with the same ids as ``embedding``
    (not with positions) -- confirm this is intended.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        table_shape = (vocab_size, embedding_dim)
        # Creation order (embedding, then pe) is kept so that seeded
        # initialisation yields the same weights.
        self.embedding = torch.nn.Embedding(*table_shape)
        self.pe = torch.nn.Embedding(*table_shape)

    def forward(self, x):
        """Return ``embedding(x) + pe(x)`` for an integer index tensor ``x``."""
        return self.embedding(x) + self.pe(x)


class AttentionBlock(torch.nn.Module):
    """Single-head self-attention with a lower-triangular (causal) mask.

    Query, key and value are bias-free linear projections of size
    ``embedding_dim -> embedding_dim``. A ``context_size x context_size``
    lower-triangular matrix of ones is stored as the buffer ``mask``.
    """

    def __init__(self, embedding_dim: int, context_size: int):
        super().__init__()
        proj_dims = (embedding_dim, embedding_dim)
        # Projections are created in the order query, key, value.
        self.query = torch.nn.Linear(*proj_dims, bias=False)
        self.key = torch.nn.Linear(*proj_dims, bias=False)
        self.value = torch.nn.Linear(*proj_dims, bias=False)

        lower_triangle = torch.tril(
            torch.ones(context_size, context_size, dtype=torch.float)
        )
        self.register_buffer("mask", lower_triangle)

    def forward(self, x):
        """Apply masked attention to a rank-3 tensor ``x``.

        The second dimension is the one attended over and the last one is
        the feature dimension used for the ``C**-0.5`` scaling.
        """
        _, seq_len, channels = x.size()

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        scores = q @ k.transpose(-2, -1) * channels**-0.5
        # Entries above the diagonal are zero in the mask: set them to -inf
        # so that softmax gives them a weight of exactly 0.
        hidden = self.mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float("-inf"))
        weights = torch.nn.functional.softmax(scores, dim=-1)

        return weights @ v


class MultiAttentionBlock(torch.nn.Module):
    """Several ``AttentionBlock`` heads whose outputs are concatenated.

    The concatenation (width ``embedding_dim * num_heads``) is projected
    back to ``embedding_dim`` by a linear layer with bias.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int):
        super().__init__()
        heads = [AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
        self.attention = torch.nn.ModuleList(heads)
        self.linear = torch.nn.Linear(embedding_dim * num_heads, embedding_dim)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last axis, then project."""
        per_head = [head(x) for head in self.attention]
        merged = torch.cat(per_head, dim=-1)
        return self.linear(merged)


class FeedForward(torch.nn.Module):
    """Two-layer perceptron: ``embedding_dim -> ff_dim -> embedding_dim``.

    A ReLU is applied between the two linear layers.
    """

    def __init__(self, embedding_dim: int, ff_dim: int):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features=embedding_dim, out_features=ff_dim)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(in_features=ff_dim, out_features=embedding_dim)

    def forward(self, x):
        """Return ``linear_2(relu(linear_1(x)))``."""
        hidden = self.relu(self.linear_1(x))
        return self.linear_2(hidden)


class DecoderLayer(torch.nn.Module):
    """Attention followed by a feed-forward network, each with a residual.

    A ``LayerNorm`` is applied to the input of each sub-block (before the
    attention and before the feed-forward network); the un-normalised
    input is added back to the sub-block output.
    """

    def __init__(self, embedding_dim: int, num_heads: int, context_size: int, ff_dim: int):
        super().__init__()
        self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
        self.feed_forward = FeedForward(embedding_dim, ff_dim)
        self.norm_1 = torch.nn.LayerNorm(embedding_dim)
        self.norm_2 = torch.nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return ``ff(norm_2(h)) + h`` where ``h = attention(norm_1(x)) + x``."""
        # The sub-block output stays on the left of each residual add, as in
        # the traced graph.
        after_attention = self.attention(self.norm_1(x)) + x
        return self.feed_forward(self.norm_2(after_attention)) + after_attention


class LLM(torch.nn.Module):
    """Toy model: an ``Embedding`` followed by a single ``DecoderLayer``.

    The constructor arguments are forwarded to the two submodules;
    ``vocab_size`` and ``embedding_dim`` go to the embedding, the rest
    (together with ``embedding_dim``) to the decoder layer.
    """

    def __init__(
        self,
        vocab_size: int = 1024,
        embedding_dim: int = 16,
        num_heads: int = 2,
        context_size: int = 256,
        ff_dim: int = 128,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

    def forward(self, input_ids):
        """Embed ``input_ids`` and pass the result through the decoder layer."""
        return self.decoder(self.embedding(input_ids))


# Build the model with its default hyper-parameters and run it once in eager
# mode on random int64 token ids of shape (1, 30).
llm = LLM()
dim = (1, 30)
input_ids = torch.randint(low=0, high=1024, size=dim, dtype=torch.int64)
y = llm(input_ids)

print(f"output: shape={y.shape}, min={y.min()}, max={y.max()}")
output: shape=torch.Size([1, 30, 16]), min=-3.806222438812256, max=4.5326056480407715

First conversion to ONNX

The conversion relies on torch.export.export(), which gives:

graph():
    %p_embedding_embedding_weight : [num_users=1] = placeholder[target=p_embedding_embedding_weight]
    %p_embedding_pe_weight : [num_users=1] = placeholder[target=p_embedding_pe_weight]
    %p_decoder_attention_attention_0_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_query_weight]
    %p_decoder_attention_attention_0_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_key_weight]
    %p_decoder_attention_attention_0_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_0_value_weight]
    %p_decoder_attention_attention_1_query_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_query_weight]
    %p_decoder_attention_attention_1_key_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_key_weight]
    %p_decoder_attention_attention_1_value_weight : [num_users=1] = placeholder[target=p_decoder_attention_attention_1_value_weight]
    %p_decoder_attention_linear_weight : [num_users=1] = placeholder[target=p_decoder_attention_linear_weight]
    %p_decoder_attention_linear_bias : [num_users=1] = placeholder[target=p_decoder_attention_linear_bias]
    %p_decoder_feed_forward_linear_1_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_weight]
    %p_decoder_feed_forward_linear_1_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_1_bias]
    %p_decoder_feed_forward_linear_2_weight : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_weight]
    %p_decoder_feed_forward_linear_2_bias : [num_users=1] = placeholder[target=p_decoder_feed_forward_linear_2_bias]
    %p_decoder_norm_1_weight : [num_users=1] = placeholder[target=p_decoder_norm_1_weight]
    %p_decoder_norm_1_bias : [num_users=1] = placeholder[target=p_decoder_norm_1_bias]
    %p_decoder_norm_2_weight : [num_users=1] = placeholder[target=p_decoder_norm_2_weight]
    %p_decoder_norm_2_bias : [num_users=1] = placeholder[target=p_decoder_norm_2_bias]
    %b_decoder_attention_attention_0_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_0_mask]
    %b_decoder_attention_attention_1_mask : [num_users=1] = placeholder[target=b_decoder_attention_attention_1_mask]
    %input_ids : [num_users=2] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_embedding_weight, %input_ids), kwargs = {})
    %embedding_1 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_embedding_pe_weight, %input_ids), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%embedding, %embedding_1), kwargs = {})
    %layer_norm : [num_users=6] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add, [16], %p_decoder_norm_1_weight, %p_decoder_norm_1_bias), kwargs = {})
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_query_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_key_weight), kwargs = {})
    %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_0_value_weight), kwargs = {})
    %transpose : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_1, -2, -1), kwargs = {})
    %matmul : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear, %transpose), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul, 0.25), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_0_mask, 0, None, 30), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, 30), kwargs = {})
    %eq : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_2, 0), kwargs = {})
    %masked_fill : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul, %eq, -inf), kwargs = {})
    %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill, -1), kwargs = {})
    %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax, %linear_2), kwargs = {})
    %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_query_weight), kwargs = {})
    %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_key_weight), kwargs = {})
    %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm, %p_decoder_attention_attention_1_value_weight), kwargs = {})
    %transpose_1 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%linear_4, -2, -1), kwargs = {})
    %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%linear_3, %transpose_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_2, 0.25), kwargs = {})
    %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%b_decoder_attention_attention_1_mask, 0, None, 30), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_3, 1, None, 30), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%slice_4, 0), kwargs = {})
    %masked_fill_1 : [num_users=1] = call_function[target=torch.ops.aten.masked_fill.Scalar](args = (%mul_1, %eq_1, -inf), kwargs = {})
    %softmax_1 : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%masked_fill_1, -1), kwargs = {})
    %matmul_3 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%softmax_1, %linear_5), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%matmul_1, %matmul_3], -1), kwargs = {})
    %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%cat, %p_decoder_attention_linear_weight, %p_decoder_attention_linear_bias), kwargs = {})
    %add_1 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_6, %add), kwargs = {})
    %layer_norm_1 : [num_users=1] = call_function[target=torch.ops.aten.layer_norm.default](args = (%add_1, [16], %p_decoder_norm_2_weight, %p_decoder_norm_2_bias), kwargs = {})
    %linear_7 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%layer_norm_1, %p_decoder_feed_forward_linear_1_weight, %p_decoder_feed_forward_linear_1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear_7,), kwargs = {})
    %linear_8 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %p_decoder_feed_forward_linear_2_weight, %p_decoder_feed_forward_linear_2_bias), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%linear_8, %add_1), kwargs = {})
    return (add_2,)

Then function to_onnx converts it into ONNX.

# Export the eager model, passing the sample input as the only positional
# argument, then print a textual rendering of the resulting ONNX model.
onx = to_onnx(llm, (input_ids,))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='b_decoder_attention_attention_0_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='b_decoder_attention_attention_1_mask' type=float32 shape=(256, 256)-- DynamoInterpret.placeholder.0
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='_reshape_init1_s_0' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='_reshape_init1_s_02' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='SliceSlicePattern_init7_s1_0_start' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilder.constant_folding.from/fold(init7_s1_0)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilder.constant_folding.from/fold(init7_s1_30)##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_0_start2' type=int64 shape=(2,) -- array([0, 0])-- GraphBuilder.constant_folding.from/fold(init7_s1_0)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_30_end2' type=int64 shape=(2,) -- array([30, 30])-- GraphBuilder.constant_folding.from/fold(init7_s1_30)##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='SliceSlicePattern_init7_s1_1_axis2' type=int64 shape=(2,) -- array([0, 1])-- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1)##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  Add(embedding, embedding_1) -> add
    LayerNormalization(add, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add00
      MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
  Transpose(linear_1, perm=[0,2,1]) -> transpose
    MatMul(linear, transpose) -> matmul
      Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
Slice(b_decoder_attention_attention_0_mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, _reshape_init1_s_20) -> eq
    Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
      Softmax(masked_fill, axis=-1) -> softmax
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
  Transpose(linear_4, perm=[0,2,1]) -> transpose_1
  MatMul(linear_3, transpose_1) -> matmul_2
    Mul(matmul_2, _reshape_init1_s_02) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Slice(b_decoder_attention_attention_1_mask, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_4
  Equal(slice_4, _reshape_init1_s_202) -> eq_1
    Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
      Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, _onx_transpose_p_decoder_attention_linear_weight0) -> _onx_matmul_cat0
        Add(_onx_matmul_cat0, decoder.attention.linear.bias) -> linear_6
    Add(linear_6, add) -> add_1
      LayerNormalization(add_1, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_100
        MatMul(_onx_div_sub_add_100, _onx_transpose_p_decoder_feed_forward_linear_1_weight0) -> _onx_matmul_layer_norm_10
          Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_7
            Relu(linear_7) -> relu
              MatMul(relu, _onx_transpose_p_decoder_feed_forward_linear_2_weight0) -> _onx_matmul_relu0
                Add(_onx_matmul_relu0, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Let’s check there is no discrepancy.

# Run the exported model with onnxruntime on CPU and compare its first output
# with the eager result `y`.
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.806222438812256, max=4.5326056480407715
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

# Write the exported model to a file (path relative to the working directory).
onnx.save(onx, "plot_exporter_recipes_c_modules.inlined.onnx")

ONNX with submodules

Let’s produce an ONNX model with submodules. Function to_onnx is calling the function torch.export.unflatten.unflatten() under the hood. The fx graph looks like the following.

graph():
    %input_ids : [num_users=1] = placeholder[target=input_ids]
    %embedding : [num_users=1] = call_module[target=embedding](args = (%input_ids,), kwargs = {})
    %decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})
    return (decoder,)

The exported graph looks simpler and shows something like:

%decoder : [num_users=1] = call_module[target=decoder](args = (%embedding,), kwargs = {})

It preserves the hierarchy but it does not necessarily preserve the signatures of the initial modules. That was not one of our goals. The tricky part is that the module called embedding is not an instance of Embedding but an instance of InterpreterModule, which contains the fx nodes contributing to the submodule and coming from the previous graph.

Now the ONNX graph.

# Same export as before, this time with export_modules_as_functions=True so
# that submodules are kept as ONNX local functions; print the result.
onx_module = to_onnx(llm, (input_ids,), export_modules_as_functions=True)
print(pretty_onnx(onx_module))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16)       -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='_onx_transpose_weight05' type=float32 shape=(32, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight05)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='_onx_transpose_weight06' type=float32 shape=(16, 128)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight06)
init: name='_onx_transpose_weight023' type=float32 shape=(128, 16)    -- GraphBuilder.make_local_function/from(_onx_transpose_weight023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, _onx_transpose_weight0) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> _reshape_init1_s_0
Constant(value=[0.0]) -> _reshape_init1_s_20
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, _reshape_init1_s_20) -> eq
MatMul(norm_1, _onx_transpose_weight02) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
  Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, _onx_transpose_weight03) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> _reshape_init1_s_02
Constant(value=[0.0]) -> _reshape_init1_s_202
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, _reshape_init1_s_202) -> eq2
MatMul(norm_1, _onx_transpose_weight04) -> query2
MatMul(norm_1, _onx_transpose_weight022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, _reshape_init1_s_02) -> _onx_mul_matmul02
  Where(eq2, init1_s1_2, _onx_mul_matmul02) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, _onx_transpose_weight032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, _onx_transpose_weight05) -> _onx_matmul_cat0
Constant(value=[0.0656113...) -> bias
  Add(_onx_matmul_cat0, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, _onx_transpose_weight06) -> _onx_matmul_layer_norm_10
      Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, _onx_transpose_weight023) -> _onx_matmul_relu0
Constant(value=[-0.025308...) -> bias2
  Add(_onx_matmul_relu0, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

We check again that there are no new discrepancies.

# Run the model exported with submodules through onnxruntime on CPU and
# compare its first output with the eager result `y`.
sess = InferenceSession(onx_module.SerializeToString(), providers=["CPUExecutionProvider"])
feeds = {"input_ids": input_ids.numpy()}
got = sess.run(None, feeds)[0]

diff = max_diff(y, got)
print(f"output: shape={got.shape}, min={got.min()}, max={got.max()}")
print(f"max discrepancy={diff['abs']}")
output: shape=(1, 30, 16), min=-3.806222438812256, max=4.5326056480407715
max discrepancy=4.76837158203125e-07

Let’s save the ONNX model.

# Write the model exported with submodules to a file (path relative to the
# working directory).
onnx.save(onx_module, "plot_exporter_recipes_c_modules.module.onnx")

And visually.

plot exporter recipes c modules

Inlining

The ONNX graph can still be inlined after this.

opset: domain='' version=18
opset: domain='aten_local_function' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='mask' type=float32 shape=(256, 256)                       -- GraphBuilder.make_local_function/from(mask)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16)       -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='mask2' type=float32 shape=(256, 256)                      -- GraphBuilder.make_local_function/from(mask2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='_onx_transpose_weight05' type=float32 shape=(32, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight05)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='_onx_transpose_weight06' type=float32 shape=(16, 128)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight06)
init: name='_onx_transpose_weight023' type=float32 shape=(128, 16)    -- GraphBuilder.make_local_function/from(_onx_transpose_weight023)
Constant(value=[1.0, 1.0,...) -> init1_s16_
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  Add(embedding2, pe) -> embedding
Constant(value=[0.0, 0.0,...) -> init1_s16_2
  LayerNormalization(embedding, init1_s16_, init1_s16_2, axis=-1, epsilon=0.00, stash_type=1) -> norm_1
    MatMul(norm_1, _onx_transpose_weight0) -> query
Constant(value=[-inf]) -> init1_s1_
Constant(value=[0.25]) -> _reshape_init1_s_0
Constant(value=[0.0]) -> _reshape_init1_s_20
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis
  Slice(mask, SliceSlicePattern_init7_s1_0_start, SliceSlicePattern_init7_s1_30_end, SliceSlicePattern_init7_s1_1_axis) -> slice_2
  Equal(slice_2, _reshape_init1_s_20) -> eq
MatMul(norm_1, _onx_transpose_weight02) -> key
  Transpose(key, perm=[0,2,1]) -> transpose
    MatMul(query, transpose) -> matmul
  Mul(matmul, _reshape_init1_s_0) -> _onx_mul_matmul0
  Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, _onx_transpose_weight03) -> value
  MatMul(softmax, value) -> attention_0
Constant(value=[-inf]) -> init1_s1_2
Constant(value=[0.25]) -> _reshape_init1_s_02
Constant(value=[0.0]) -> _reshape_init1_s_202
Constant(value=[0, 0]) -> SliceSlicePattern_init7_s1_0_start2
Constant(value=[30, 30]) -> SliceSlicePattern_init7_s1_30_end2
Constant(value=[0, 1]) -> SliceSlicePattern_init7_s1_1_axis2
  Slice(mask2, SliceSlicePattern_init7_s1_0_start2, SliceSlicePattern_init7_s1_30_end2, SliceSlicePattern_init7_s1_1_axis2) -> slice_22
  Equal(slice_22, _reshape_init1_s_202) -> eq2
MatMul(norm_1, _onx_transpose_weight04) -> query2
MatMul(norm_1, _onx_transpose_weight022) -> key2
  Transpose(key2, perm=[0,2,1]) -> transpose2
  MatMul(query2, transpose2) -> matmul2
  Mul(matmul2, _reshape_init1_s_02) -> _onx_mul_matmul02
  Where(eq2, init1_s1_2, _onx_mul_matmul02) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
MatMul(norm_1, _onx_transpose_weight032) -> value2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, _onx_transpose_weight05) -> _onx_matmul_cat0
Constant(value=[0.0656113...) -> bias
  Add(_onx_matmul_cat0, bias) -> attention
    Add(attention, embedding) -> add_1
Constant(value=[1.0, 1.0,...) -> init1_s16_3
Constant(value=[0.0, 0.0,...) -> init1_s16_22
  LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
    MatMul(norm_2, _onx_transpose_weight06) -> _onx_matmul_layer_norm_10
      Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_1
        Relu(linear_1) -> relu
          MatMul(relu, _onx_transpose_weight023) -> _onx_matmul_relu0
Constant(value=[-0.025308...) -> bias2
  Add(_onx_matmul_relu0, bias2) -> feed_forward
    Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Optimizations

The ONNX graph produced by the exporter without any optimization is very verbose and less efficient than it could be. That’s why some optimizations are applied to the model by default. It is also possible to introduce kernels implemented in onnxruntime. Let’s see how it goes.

onx_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2
    ),
)
print(pretty_onnx(onx_optimized))
[GraphBuilder-ZWA.optimize] start with 73 nodes
[GraphBuilder-ZWA.optimize] #patterns=66
[GraphBuilder-ZWA.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-ZWA.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-ZWA.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-ZWA.optimize] start with 53 nodes, 28 initializers, 66 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   1/66 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   2/66 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   3/66 - P0 - CastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   4/66 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   5/66 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   6/66 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   7/66 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   8/66 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern   9/66 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  10/66 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  11/66 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  12/66 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  13/66 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  14/66 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  15/66 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  16/66 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  17/66 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  18/66 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  19/66 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  20/66 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  21/66 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  22/66 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  23/66 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  24/66 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  25/66 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  26/66 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  27/66 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  28/66 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  29/66 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  30/66 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  31/66 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  32/66 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  33/66 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  34/66 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  35/66 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  36/66 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  37/66 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  38/66 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  39/66 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  40/66 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  41/66 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  42/66 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  43/66 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  44/66 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  45/66 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  46/66 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  47/66 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  48/66 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  49/66 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  50/66 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  51/66 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  52/66 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  53/66 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  54/66 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  55/66 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  56/66 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  57/66 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  58/66 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  59/66 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  60/66 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  61/66 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  62/66 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  63/66 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  64/66 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  65/66 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] use pattern  66/66 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-ZWA.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.006 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-ZWA.optimize] increase priority to 1
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-ZWA.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-ZWA.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-ZWA.optimize] increase priority to 2
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-ZWA.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-ZWA.optimize] increase priority to 3
[GraphBuilderPatternOptimization-ZWA.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-ZWA.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-ZWA.optimize] done after 8 iterations with 29 nodes in 0.033
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.000173841996002011
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.00043014800030505285
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00024007700267247856
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.00044869499834021553
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00013733400555793196
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014058530068723485
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0001476519973948598
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0012261480078450404
    STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.00028686600126093253
    STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.561399742262438e-05
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0002683900165720843
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00021369899332057685
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013751999358646572
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00012267099373275414
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00042915101221296936
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002062719941022806
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00039391099562635645
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.0002374749892624095
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014003500109538436
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00023889500880613923
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00020226900232955813
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012630200217245147
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001336939967586659
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0001983599941013381
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00012818598770536482
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013789199147140607
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.229200516827404e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00018165100482292473
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00036033200012752786
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.089899397920817e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.9845006035175174e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0017957130112336017
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0026028189895441756
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=6.459995347540826e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001596389920450747
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.00223082799493568
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00026810899726115167
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001514590039732866
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.002048662005108781
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.953899380983785e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006744000129401684
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003577850147848949
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002119630080414936
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002014479978242889
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013585301348939538
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018921001174021512
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018153598648495972
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004234040097799152
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.0284998754505068e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00030415500077651814
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00020345000666566193
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00027579999732552096
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002137810006388463
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019520599016686901
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0006518419977510348
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014285900397226214
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.000290568990749307
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.00020128698815824464
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001317569985985756
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013952799781691283
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00408375401457306
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001676660103839822
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001409980104654096
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00022318900300888345
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00019362599414307624
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00042318199848523363
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002227019940619357
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00021736099006375298
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.86179948085919e-05
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032966199796646833
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034290101029910147
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002329110211576335
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033013800566550344
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002353209929424338
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002007650036830455
    STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0007423379938700236
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  29 x 1t
          INIT:   1 x 7t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   8 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   7 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      INIT:   1 x 7t[1]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-ZWA.remove_unused] remove_initializer 1:5/30:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 2:6/30:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 3:7/30:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 4:8/30:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-ZWA.remove_unused] remove_initializer 5:9/30:init7_s1_-1:int64[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 6:10/30:init1_s1_:float32[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 7:11/30:init1_s1_2:float32[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 8:16/30:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-ZWA.remove_unused] remove_initializer 9:22/30:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-ZWA.optimize] done with 29 nodes in 0.042
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add00, unused, unused2, add
    MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
  Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, _onx_transpose_p_decoder_attention_linear_weight0) -> _onx_matmul_cat0
        Add(_onx_matmul_cat0, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_100, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_100, _onx_transpose_p_decoder_feed_forward_linear_1_weight0) -> _onx_matmul_layer_norm_10
        Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, _onx_transpose_p_decoder_feed_forward_linear_2_weight0) -> _onx_matmul_relu0
              Add(_onx_matmul_relu0, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

This shows a kernel FusedMatMul[com.microsoft], which implements a kernel equivalent to Gemm but working on any tensors, not only 2D ones. How does it work on the model which exports the modules as local functions? The optimizer optimizes every local function independently. We reduce the verbosity…

# Export `llm` with the "default+onnxruntime" optimization patterns and
# constant folding enabled, passing export_modules_as_functions=True so that
# the submodules are exported as local functions, then print a readable
# text rendering of the resulting ONNX model.
onx_module_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(patterns="default+onnxruntime", constant_folding=True),
    export_modules_as_functions=True,
)
print(pretty_onnx(onx_module_optimized))
opset: domain='' version=18
opset: domain='aten_local_function' version=1
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- GraphBuilder.make_local_function/from(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- GraphBuilder.make_local_function/from(embedding.pe.weight)
init: name='_onx_transpose_weight0' type=float32 shape=(16, 16)       -- GraphBuilder.make_local_function/from(_onx_transpose_weight0)
init: name='_onx_transpose_weight02' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight02)
init: name='_onx_transpose_weight03' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight03)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_2)
init: name='_onx_transpose_weight04' type=float32 shape=(16, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight04)
init: name='_onx_transpose_weight022' type=float32 shape=(16, 16)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight022)
init: name='_onx_transpose_weight032' type=float32 shape=(16, 16)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight032)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.make_local_function/from(slice_4)
init: name='_onx_transpose_weight05' type=float32 shape=(32, 16)      -- GraphBuilder.make_local_function/from(_onx_transpose_weight05)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- GraphBuilder.make_local_function/from(decoder.feed_forward.linear_1.bias)
init: name='_onx_transpose_weight06' type=float32 shape=(16, 128)     -- GraphBuilder.make_local_function/from(_onx_transpose_weight06)
init: name='_onx_transpose_weight023' type=float32 shape=(128, 16)    -- GraphBuilder.make_local_function/from(_onx_transpose_weight023)
init: name='init1_s16_3' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_22' type=float32 shape=(16,)                    -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_' type=float32 shape=(16,)                      -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s16_2' type=float32 shape=(16,)                     -- GraphBuilder.constant_folding.from/fold()
init: name='bias2' type=float32 shape=(16,)                           -- GraphBuilder.constant_folding.from/fold()
init: name='bias' type=float32 shape=(16,)                            -- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_2' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='init1_s1_' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold()
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding2
Gather(embedding.pe.weight, input_ids) -> pe
  SkipLayerNormalization[com.microsoft](embedding2, pe, init1_s16_, init1_s16_2, epsilon=0.00) -> norm_1, unused, unused2, embedding
    MatMul(norm_1, _onx_transpose_weight0) -> query
MatMul(norm_1, _onx_transpose_weight02) -> key
  FusedMatMul[com.microsoft](query, key, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
  Where(eq, init1_s1_, _onx_mul_matmul0) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(norm_1, _onx_transpose_weight03) -> value
  MatMul(softmax, value) -> attention_0
MatMul(norm_1, _onx_transpose_weight04) -> query2
MatMul(norm_1, _onx_transpose_weight022) -> key2
  FusedMatMul[com.microsoft](query2, key2, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul02
MatMul(norm_1, _onx_transpose_weight032) -> value2
Equal(slice_4, _reshape_init1_s_202) -> eq2
  Where(eq2, init1_s1_2, _onx_mul_matmul02) -> masked_fill2
    Softmax(masked_fill2, axis=-1) -> softmax2
  MatMul(softmax2, value2) -> attention_1
    Concat(attention_0, attention_1, axis=-1) -> cat
      MatMul(cat, _onx_transpose_weight05) -> _onx_matmul_cat0
        Add(_onx_matmul_cat0, bias) -> attention
    Add(attention, embedding) -> add_1
      LayerNormalization(add_1, init1_s16_3, init1_s16_22, axis=-1, epsilon=0.00, stash_type=1) -> norm_2
        MatMul(norm_2, _onx_transpose_weight06) -> _onx_matmul_layer_norm_10
          Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_1
            Relu(linear_1) -> relu
              MatMul(relu, _onx_transpose_weight023) -> _onx_matmul_relu0
                Add(_onx_matmul_relu0, bias2) -> feed_forward
      Add(feed_forward, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

It seems to be working as well on this simple case even though the optimizers were not tested on such models. However, keeping the submodule information might be useful to implement optimizers for a family of models sharing the same components.

Optimizations for CUDA

The optimizer may behave differently knowing the model is running on CUDA. It may use different kernels and do different optimizations if needed. That may not be the right place to do it, as the runtime may choose to run one kernel on CPU and another one on CUDA. The current optimization does not know that and is not able to decide which provider would be more useful for some kernels. This could even be decided at runtime.

# Export `llm` with the "default+onnxruntime" patterns and constant folding,
# this time telling the optimizer that the target processor is CUDA.
# verbose=2 makes the optimizer print its progress (patterns tried, matches
# applied, per-pattern statistics) before the model itself is printed.
onx_cuda_optimized = to_onnx(
    llm,
    (input_ids,),
    options=OptimizationOptions(
        patterns="default+onnxruntime", constant_folding=True, verbose=2, processor="CUDA"
    ),
)
print(pretty_onnx(onx_cuda_optimized))
[GraphBuilder-VNC.optimize] start with 73 nodes
[GraphBuilder-VNC.optimize] #patterns=66
[GraphBuilder-VNC.remove_unused] remove_initializer 1:1/47:embedding.embedding.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 2:3/47:embedding.pe.weight:torch.float32[torch.Size([1024, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 3:5/47:decoder.attention.attention.0.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 4:7/47:decoder.attention.attention.0.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 5:9/47:decoder.attention.attention.0.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 6:11/47:decoder.attention.attention.1.query.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 7:13/47:decoder.attention.attention.1.key.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 8:15/47:decoder.attention.attention.1.value.weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 9:17/47:decoder.attention.linear.weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-VNC.remove_unused] remove_initializer 10:19/47:decoder.attention.linear.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 11:21/47:decoder.feed_forward.linear_1.weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 12:23/47:decoder.feed_forward.linear_1.bias:torch.float32[torch.Size([128])]
[GraphBuilder-VNC.remove_unused] remove_initializer 13:25/47:decoder.feed_forward.linear_2.weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-VNC.remove_unused] remove_initializer 14:27/47:decoder.feed_forward.linear_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 15:29/47:decoder.norm_1.weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 16:31/47:decoder.norm_1.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 17:33/47:decoder.norm_2.weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 18:35/47:decoder.norm_2.bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 1:2/46:p_decoder_attention_attention_0_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 2:3/46:p_decoder_attention_attention_0_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 3:4/46:p_decoder_attention_attention_0_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 4:5/46:p_decoder_attention_attention_1_query_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 5:6/46:p_decoder_attention_attention_1_key_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 6:7/46:p_decoder_attention_attention_1_value_weight:torch.float32[torch.Size([16, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 7:8/46:p_decoder_attention_linear_weight:torch.float32[torch.Size([16, 32])]
[GraphBuilder-VNC.remove_unused] remove_initializer 8:10/46:p_decoder_feed_forward_linear_1_weight:torch.float32[torch.Size([128, 16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 9:12/46:p_decoder_feed_forward_linear_2_weight:torch.float32[torch.Size([16, 128])]
[GraphBuilder-VNC.remove_unused] remove_initializer 10:18/46:b_decoder_attention_attention_0_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-VNC.remove_unused] remove_initializer 11:19/46:b_decoder_attention_attention_1_mask:torch.float32[torch.Size([256, 256])]
[GraphBuilder-VNC.remove_unused] remove_initializer 12:23/46:init1_s_:float32[()]
[GraphBuilder-VNC.remove_unused] remove_initializer 13:24/46:init7_s1_1:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 14:25/46:init7_s1_0:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 15:26/46:init7_s1_30:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 16:27/46:init1_s_2:float32[()]
[GraphBuilder-VNC.remove_unused] remove_initializer 17:33/46:slice_1:torch.float32[torch.Size([30, 256])]
[GraphBuilder-VNC.remove_unused] remove_initializer 18:40/46:slice_3:torch.float32[torch.Size([30, 256])]
[GraphBuilderPatternOptimization-VNC.optimize] start with 53 nodes, 28 initializers, 66 patterns, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   1/66 - P0 - BatchNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   2/66 - P0 - BatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   3/66 - P0 - CastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   4/66 - P0 - ConvBiasNullPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   5/66 - P0 - ExpandPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   6/66 - P0 - GeluErfPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   7/66 - P0 - GeluOrtPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   8/66 - P0 - GeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern   9/66 - P0 - IdentityPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  10/66 - P0 - LeakyReluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  11/66 - P0 - ReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  12/66 - P0 - ReshapeReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  13/66 - P0 - SameChildrenPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  14/66 - P0 - SoftmaxCrossEntropyLossCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  15/66 - P0 - SqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  16/66 - P0 - TransposeReshapeTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  17/66 - P0 - TransposeTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  18/66 - P0 - UnsqueezeUnsqueezePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  19/66 - P1 - BiasGeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  20/66 - P1 - BiasSoftmaxPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  21/66 - P1 - CastCastBinaryPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  22/66 - P1 - CastLayerNormalizationCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  23/66 - P1 - CastOpCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  24/66 - P1 - ClipClipPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  25/66 - P1 - ComputationCastOpCastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  26/66 - P1 - DropoutPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  27/66 - P1 - ExpandBroadcastPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  28/66 - P1 - ExpandSwapPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  29/66 - P1 - FastGeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  30/66 - P1 - GemmTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  31/66 - P1 - LayerNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  32/66 - P1 - LayerNormalizationScalePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  33/66 - P1 - MatMulReshape2Of3Pattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  34/66 - P1 - MulMulMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  35/66 - P1 - MulMulMulScalarPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  36/66 - P1 - OrtBatchNormalizationTrainingPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  37/66 - P1 - QuickGeluPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  38/66 - P1 - ReduceReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  39/66 - P1 - ReduceSumNormalizePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  40/66 - P1 - Reshape2Of3Pattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  41/66 - P1 - ReshapeMatMulReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  42/66 - P1 - ReshapeReshapeBinaryPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  43/66 - P1 - RotaryConcatPartPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  44/66 - P1 - SequenceConstructAtPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  45/66 - P1 - SimplifiedLayerNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  46/66 - P1 - SkipLayerNormalizationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  47/66 - P1 - SliceSlicePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  48/66 - P1 - SlicesSplitPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  49/66 - P1 - SoftmaxGradPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  50/66 - P1 - SplitConcatPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  51/66 - P1 - Sub1MulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  52/66 - P1 - SwitchOrderBinaryPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  53/66 - P1 - SwitchReshapeActivationPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  54/66 - P1 - TransposeEqualReshapePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  55/66 - P1 - TransposeMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  56/66 - P1 - TransposeReshapeMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  57/66 - P1 - UnsqueezeEqualPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  58/66 - P2 - FusedConvPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  59/66 - P2 - FusedMatMulDivPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  60/66 - P2 - FusedMatMulPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  61/66 - P3 - AttentionPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  62/66 - P3 - FusedMatMulTransposePattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  63/66 - P3 - FusedMatMulx2Pattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  64/66 - P3 - MatMulAddPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  65/66 - P3 - ReshapeGemmPattern()
[GraphBuilderPatternOptimization-VNC.optimize] use pattern  66/66 - P3 - TransposeFusedMatMulBPattern()
[GraphBuilderPatternOptimization-VNC.optimize] iteration 0: 53 nodes, priority=0
[GraphBuilderPatternOptimization-VNC.optimize] applies 6 matches, 2*CastPattern, 4*IdentityPattern - time=0.008 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-VNC.optimize] iteration 1: 47 nodes, priority=0
[GraphBuilderPatternOptimization-VNC.optimize] increase priority to 1
[GraphBuilderPatternOptimization-VNC.optimize] iteration 2: 47 nodes, priority=1
[GraphBuilderPatternOptimization-VNC.optimize] applies 2 matches, 2*LayerNormalizationPattern - time=0.004 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-VNC.optimize] iteration 3: 35 nodes, priority=1
[GraphBuilderPatternOptimization-VNC.optimize] applies 2 matches, 2*SkipLayerNormalizationPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-VNC.optimize] iteration 4: 33 nodes, priority=1
[GraphBuilderPatternOptimization-VNC.optimize] increase priority to 2
[GraphBuilderPatternOptimization-VNC.optimize] iteration 5: 33 nodes, priority=2
[GraphBuilderPatternOptimization-VNC.optimize] applies 2 matches, 2*FusedMatMulPattern - time=0.003 | max_time=IdentityPattern:0.000
[GraphBuilderPatternOptimization-VNC.optimize] iteration 6: 29 nodes, priority=2
[GraphBuilderPatternOptimization-VNC.optimize] increase priority to 3
[GraphBuilderPatternOptimization-VNC.optimize] iteration 7: 29 nodes, priority=3
[GraphBuilderPatternOptimization-VNC.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-VNC.optimize] done after 8 iterations with 29 nodes in 0.038
    STAT apply_CastPattern +2 -2 #it=1 maxmatch=1 i=2 - time=0.00026056999195134267
    STAT apply_FusedMatMulPattern +2 -6 #it=1 maxmatch=1 i=2 - time=0.0005157869964023121
    STAT apply_IdentityPattern +4 -4 #it=1 maxmatch=5 i=4 - time=0.00030466200405498967
    STAT apply_LayerNormalizationPattern +2 -14 #it=1 maxmatch=1 i=2 - time=0.0005376979970606044
    STAT apply_SkipLayerNormalizationPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00022099500347394496
    STAT build_graph_for_pattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0014846519916318357
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00021267700503813103
    STAT check_pattern_A0 +0 -0 #it=4 maxmatch=0 i=0 - time=0.0015150120161706582
    STAT check_pattern_B0 +0 -0 #it=3 maxmatch=0 i=0 - time=0.0003633899978012778
    STAT match_AttentionPattern +0 -0 #it=1 maxmatch=0 i=0 - time=2.4903994926717132e-05
    STAT match_BatchNormalizationPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.0004113059985684231
    STAT match_BatchNormalizationTrainingPattern +0 -0 #it=8 maxmatch=0 i=0 - time=0.00028710500191664323
    STAT match_BiasGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001591440086485818
    STAT match_BiasSoftmaxPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002895809957408346
    STAT match_CastCastBinaryPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004520799921010621
    STAT match_CastLayerNormalizationCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0002281540000694804
    STAT match_CastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0004967470013070852
    STAT match_CastPattern +0 -0 #it=8 maxmatch=2 i=2 - time=0.00029728899971814826
    STAT match_ClipClipPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00015165000513661653
    STAT match_ComputationCastOpCastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00026441601221449673
    STAT match_ConvBiasNullPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.00024405600561294705
    STAT match_DropoutPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001320449955528602
    STAT match_ExpandBroadcastPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.0001397379965055734
    STAT match_ExpandPattern +0 -0 #it=8 maxmatch=2 i=0 - time=0.0002740909985732287
    STAT match_ExpandSwapPattern +0 -0 #it=6 maxmatch=0 i=0 - time=0.00014128500333754346
    STAT match_FastGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014789899432798848
    STAT match_FusedConvPattern +0 -0 #it=3 maxmatch=0 i=0 - time=6.449100328609347e-05
    STAT match_FusedMatMulDivPattern +0 -0 #it=3 maxmatch=2 i=0 - time=0.00014306599769042805
    STAT match_FusedMatMulPattern +0 -0 #it=3 maxmatch=2 i=2 - time=0.00035403600486461073
    STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=7.892500434536487e-05
    STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.7320998166687787e-05
    STAT match_GeluErfPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00228474000323331
    STAT match_GeluOrtPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.003075333996093832
    STAT match_GeluPattern +0 -0 #it=8 maxmatch=2 i=0 - time=7.80199479777366e-06
    STAT match_GemmTransposePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001459740087739192
    STAT match_IdentityPattern +0 -0 #it=8 maxmatch=6 i=4 - time=0.0026902310055447742
    STAT match_LayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0003888280116370879
    STAT match_LayerNormalizationScalePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00018430400814395398
    STAT match_LeakyReluPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0028149299978394993
    STAT match_MatMulAddPattern +0 -0 #it=1 maxmatch=0 i=0 - time=6.969700189074501e-05
    STAT match_MatMulReshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0006638829945586622
    STAT match_MulMulMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00035725799534702674
    STAT match_MulMulMulScalarPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002011279866565019
    STAT match_OrtBatchNormalizationTrainingPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00024763400870142505
    STAT match_QuickGeluPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0001439319967175834
    STAT match_ReduceReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020257899450371042
    STAT match_ReduceSumNormalizePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015912499657133594
    STAT match_Reshape2Of3Pattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0004544649927993305
    STAT match_ReshapeGemmPattern +0 -0 #it=1 maxmatch=0 i=0 - time=1.912900188472122e-05
    STAT match_ReshapeMatMulReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00032406898390036076
    STAT match_ReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00033377300132997334
    STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0002825109986588359
    STAT match_ReshapeReshapePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026791998971020803
    STAT match_RotaryConcatPartPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00020742399647133425
    STAT match_SameChildrenPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0007477519975509495
    STAT match_SequenceConstructAtPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014944799477234483
    STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00028798499261029065
    STAT match_SkipLayerNormalizationPattern +0 -0 #it=6 maxmatch=2 i=2 - time=0.0001661589994910173
    STAT match_SliceSlicePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013994799519423395
    STAT match_SlicesSplitPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00015538500883849338
    STAT match_SoftmaxCrossEntropyLossCastPattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.004016680999484379
    STAT match_SoftmaxGradPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00013662099809153005
    STAT match_SplitConcatPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00014305799413705245
    STAT match_SqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002699020114960149
    STAT match_Sub1MulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00016476198652526364
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00046992699935799465
    STAT match_SwitchReshapeActivationPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022256100055528805
    STAT match_TransposeEqualReshapePattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00022111900034360588
    STAT match_TransposeFusedMatMulBPattern +0 -0 #it=1 maxmatch=0 i=0 - time=9.120399772655219e-05
    STAT match_TransposeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.0003470090086921118
    STAT match_TransposeReshapeMatMulPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00034353700175415725
    STAT match_TransposeReshapeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026978999812854454
    STAT match_TransposeTransposePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.00026462300593266264
    STAT match_UnsqueezeEqualPattern +0 -0 #it=6 maxmatch=2 i=0 - time=0.00023889799922471866
    STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=8 maxmatch=6 i=0 - time=0.0002689320099307224
    STAT remove_identity_nodes +9 -15 #it=3 maxmatch=0 i=0 - time=0.0008166290062945336
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--
         INPUT:   1 x 7t
     INPUT-SEQ:   1 x Falset
        OUTPUT:   1 x 1t
    OUTPUT-SEQ:   1 x Falset
          INIT:  29 x 1t
          INIT:   1 x 7t
          NODE:   4 x Add
          NODE:   1 x Concat
          NODE:   2 x Equal
          NODE:   2 x Gather
          NODE:  11 x MatMul
          NODE:   1 x Relu
          NODE:   2 x Softmax
          NODE:   2 x Where
          NODE:   2 x com.microsoft.FusedMatMul
          NODE:   2 x com.microsoft.SkipLayerNormalization
--MODEL: 29 nodes, 1 inputs, 1 outputs, 30 initializers--DETAILED--
     INPUT:   1 x 7t[1x30]
    OUTPUT:   1 x 1t[1x30x16]
      INIT:   2 x 1t[1024x16]
      INIT:   1 x 1t[128]
      INIT:   1 x 1t[128x16]
      INIT:   8 x 1t[16]
      INIT:   1 x 1t[16x128]
      INIT:   6 x 1t[16x16]
      INIT:   7 x 1t[1]
      INIT:   2 x 1t[30x30]
      INIT:   1 x 1t[32x16]
      INIT:   1 x 7t[1]
      NODE:   1 x Add -SIG- 1t[1x30x128], 1t[128]
      NODE:   2 x Add -SIG- 1t[1x30x16], 1t[16]
      NODE:   1 x Add -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   1 x Concat -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x Equal -SIG- 1t[30x30], 1t[1]
      NODE:   2 x Gather -SIG- 1t[1024x16], 7t[1x30]
      NODE:   1 x MatMul -SIG- 1t[1x30x128], 1t[128x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x16], 1t[16x128]
      NODE:   6 x MatMul -SIG- 1t[1x30x16], 1t[16x16]
      NODE:   2 x MatMul -SIG- 1t[1x30x30], 1t[1x30x16]
      NODE:   1 x MatMul -SIG- 1t[1x30x32], 1t[32x16]
      NODE:   1 x Relu -SIG- 1t[1x30x128]
      NODE:   2 x Softmax -SIG- 1t[1x30x30]
      NODE:   2 x Where -SIG- 9t[30x30], 1t[1], 1t[1x30x30]
      NODE:   2 x com.microsoft.FusedMatMul -SIG- 1t[1x30x16], 1t[1x30x16]
      NODE:   2 x com.microsoft.SkipLayerNormalization -SIG- 1t[1x30x16], 1t[1x30x16], 1t[16], 1t[16]
[GraphBuilder-VNC.remove_unused] remove_initializer 1:5/30:p_decoder_norm_1_weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 2:6/30:p_decoder_norm_1_bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 3:7/30:p_decoder_norm_2_weight:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 4:8/30:p_decoder_norm_2_bias:torch.float32[torch.Size([16])]
[GraphBuilder-VNC.remove_unused] remove_initializer 5:9/30:init7_s1_-1:int64[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 6:10/30:init1_s1_:float32[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 7:11/30:init1_s1_2:float32[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 8:16/30:_reshape_init1_s_0:float32[(1,)]
[GraphBuilder-VNC.remove_unused] remove_initializer 9:22/30:_reshape_init1_s_02:float32[(1,)]
[GraphBuilder-VNC.optimize] done with 29 nodes in 0.050
opset: domain='' version=18
opset: domain='com.microsoft' version=1
input: name='input_ids' type=dtype('int64') shape=[1, 30]
init: name='init1_s1_3' type=float32 shape=(1,) -- array([-inf], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='_onx_transpose_p_decoder_attention_attention_0_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_query_weight)##p_decoder_attention_attention_0_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_key_weight)##p_decoder_attention_attention_0_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_0_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_0_value_weight)##p_decoder_attention_attention_0_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.0.value.weight)
init: name='slice_2' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_1)##slice_1/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_0_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_0_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_20' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_attention_1_query_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_query_weight)##p_decoder_attention_attention_1_query_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.query.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_key_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_key_weight)##p_decoder_attention_attention_1_key_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.key.weight)
init: name='_onx_transpose_p_decoder_attention_attention_1_value_weight0' type=float32 shape=(16, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_attention_1_value_weight)##p_decoder_attention_attention_1_value_weight/DynamoInterpret.placeholder.1/P(decoder.attention.attention.1.value.weight)
init: name='slice_4' type=float32 shape=(30, 30)                      -- GraphBuilder.constant_folding.from/fold(init7_s1_0,init7_s1_1,init7_s1_30,slice_3)##slice_3/GraphBuilder.constant_folding.from/fold(b_decoder_attention_attention_1_mask,init7_s1_0,init7_s1_30)##b_decoder_attention_attention_1_mask/DynamoInterpret.placeholder.0##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_0/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_30/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_reshape_init1_s_202' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='_onx_transpose_p_decoder_attention_linear_weight0' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_attention_linear_weight)##p_decoder_attention_linear_weight/DynamoInterpret.placeholder.1/P(decoder.attention.linear.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_1_weight0' type=float32 shape=(16, 128)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_1_weight)##p_decoder_feed_forward_linear_1_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.weight)
init: name='_onx_transpose_p_decoder_feed_forward_linear_2_weight0' type=float32 shape=(128, 16)-- GraphBuilder.constant_folding.from/fold(p_decoder_feed_forward_linear_2_weight)##p_decoder_feed_forward_linear_2_weight/DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.weight)
init: name='init1_s16_' type=float32 shape=(16,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s16_2' type=float32 shape=(16,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='embedding.embedding.weight' type=float32 shape=(1024, 16) -- DynamoInterpret.placeholder.1/P(embedding.embedding.weight)
init: name='embedding.pe.weight' type=float32 shape=(1024, 16)        -- DynamoInterpret.placeholder.1/P(embedding.pe.weight)
init: name='decoder.attention.linear.bias' type=float32 shape=(16,)   -- DynamoInterpret.placeholder.1/P(decoder.attention.linear.bias)
init: name='decoder.feed_forward.linear_1.bias' type=float32 shape=(128,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_1.bias)
init: name='decoder.feed_forward.linear_2.bias' type=float32 shape=(16,)-- DynamoInterpret.placeholder.1/P(decoder.feed_forward.linear_2.bias)
Equal(slice_2, _reshape_init1_s_20) -> eq
Gather(embedding.embedding.weight, input_ids) -> embedding
Gather(embedding.pe.weight, input_ids) -> embedding_1
  SkipLayerNormalization[com.microsoft](embedding, embedding_1, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add00, unused, unused2, add
    MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_query_weight0) -> linear
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_key_weight0) -> linear_1
  FusedMatMul[com.microsoft](linear, linear_1, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul0
  Where(eq, init1_s1_3, _onx_mul_matmul0) -> masked_fill
    Softmax(masked_fill, axis=-1) -> softmax
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_0_value_weight0) -> linear_2
  MatMul(softmax, linear_2) -> matmul_1
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_query_weight0) -> linear_3
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_key_weight0) -> linear_4
  FusedMatMul[com.microsoft](linear_3, linear_4, alpha=0.25, transA=0, transB=1, transBatchA=0, transBatchB=0) -> _onx_mul_matmul_20
MatMul(_onx_div_sub_add00, _onx_transpose_p_decoder_attention_attention_1_value_weight0) -> linear_5
Equal(slice_4, _reshape_init1_s_202) -> eq_1
  Where(eq_1, init1_s1_3, _onx_mul_matmul_20) -> masked_fill_1
    Softmax(masked_fill_1, axis=-1) -> softmax_1
  MatMul(softmax_1, linear_5) -> matmul_3
    Concat(matmul_1, matmul_3, axis=-1) -> cat
      MatMul(cat, _onx_transpose_p_decoder_attention_linear_weight0) -> _onx_matmul_cat0
        Add(_onx_matmul_cat0, decoder.attention.linear.bias) -> linear_6
    SkipLayerNormalization[com.microsoft](linear_6, add, init1_s16_, init1_s16_2, epsilon=0.00) -> _onx_div_sub_add_100, unused3, unused4, add_1
      MatMul(_onx_div_sub_add_100, _onx_transpose_p_decoder_feed_forward_linear_1_weight0) -> _onx_matmul_layer_norm_10
        Add(_onx_matmul_layer_norm_10, decoder.feed_forward.linear_1.bias) -> linear_7
          Relu(linear_7) -> relu
            MatMul(relu, _onx_transpose_p_decoder_feed_forward_linear_2_weight0) -> _onx_matmul_relu0
              Add(_onx_matmul_relu0, decoder.feed_forward.linear_2.bias) -> linear_8
      Add(linear_8, add_1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 30, 16]

Comparison of the optimized and non-optimized models

The following tool tries to match the nodes and the shape inference results of two models. If they are not too different, the function is able to find the differences. We can use it to see which operators were fused into bigger ones only implemented by onnxruntime.

# Execute the original and the optimized model and pair up their results.
# Per the verbose log printed below, the helper generates inputs, runs the
# first model, then the second, and computes an edit distance to align them.
# NOTE(review): presumably both models receive the same generated inputs and
# ExtendedReferenceEvaluator is the runtime used for both -- confirm in the
# compare_onnx_execution documentation.
res1, res2, align, dc = compare_onnx_execution(
    onx, onx_optimized, verbose=1, cls=ExtendedReferenceEvaluator
)
print("------------")
# Render the alignment as text: one row per pair, first model on the left,
# second model on the right (see the table printed below).
text = dc.to_str(res1, res2, align)
print(text)
[compare_onnx_execution] generate inputs
[compare_onnx_execution] execute with 1 inputs
[compare_onnx_execution] execute first model
[compare_onnx_execution] got 68 results
[compare_onnx_execution] execute second model
[compare_onnx_execution] got 68 results (first model)
[compare_onnx_execution] got 58 results (second model)
[compare_onnx_execution] compute edit distance
[compare_onnx_execution] got 68 pairs
[compare_onnx_execution] done
------------
001 ~ | INITIA float32  2:256x256            AOCQ                 b_ | INITIA float32  1:1                  ?AAA                 in
002 - | INITIA float32  2:256x256            AOCQ                 b_ |
003 - | INITIA float32  1:1                  ?AAA                 in |
004 = | INITIA float32  2:16x16              ABAB                 _o | INITIA float32  2:16x16              ABAB                 _o
005 = | INITIA float32  2:16x16              AAAA                 _o | INITIA float32  2:16x16              AAAA                 _o
006 = | INITIA float32  2:16x16              ABAB                 _o | INITIA float32  2:16x16              ABAB                 _o
007 ~ | INITIA float32  1:1                  AAAA                 _r | INITIA float32  2:30x30              KGSP                 sl
008 = | INITIA float32  1:1                  AAAA                 _r | INITIA float32  1:1                  AAAA                 _r
009 = | INITIA float32  2:16x16              AACB                 _o | INITIA float32  2:16x16              AACB                 _o
010 = | INITIA float32  2:16x16              BBAZ                 _o | INITIA float32  2:16x16              BBAZ                 _o
011 = | INITIA float32  2:16x16              BZYZ                 _o | INITIA float32  2:16x16              BZYZ                 _o
012 ~ | INITIA float32  1:1                  AAAA                 _r | INITIA float32  2:30x30              KGSP                 sl
013 = | INITIA float32  1:1                  AAAA                 _r | INITIA float32  1:1                  AAAA                 _r
014 = | INITIA float32  2:32x16              ZAAB                 _o | INITIA float32  2:32x16              ZAAB                 _o
015 = | INITIA float32  2:16x128             AFZT                 _o | INITIA float32  2:16x128             AFZT                 _o
016 = | INITIA float32  2:128x16             AAAB                 _o | INITIA float32  2:128x16             AAAB                 _o
017 = | INITIA float32  1:16                 EEEE                 in | INITIA float32  1:16                 EEEE                 in
018 = | INITIA float32  1:16                 AAAA                 in | INITIA float32  1:16                 AAAA                 in
019 - | INITIA int64    1:2                  AAAA                 Sl |
020 - | INITIA int64    1:2                  EEAA                 Sl |
021 - | INITIA int64    1:2                  ABAA                 Sl |
022 - | INITIA int64    1:2                  AAAA                 Sl |
023 - | INITIA int64    1:2                  EEAA                 Sl |
024 - | INITIA int64    1:2                  ABAA                 Sl |
025 = | INITIA float32  2:1024x16            SCTO                 em | INITIA float32  2:1024x16            SCTO                 em
026 = | INITIA float32  2:1024x16            BMVK                 em | INITIA float32  2:1024x16            BMVK                 em
027 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
028 = | INITIA float32  1:128                AAAA                 de | INITIA float32  1:128                AAAA                 de
029 = | INITIA float32  1:16                 AAAA                 de | INITIA float32  1:16                 AAAA                 de
030 = | INPUT  int64    2:1x30               COAD                 in | INPUT  int64    2:1x30               COAD                 in
031 = | RESULT float32  3:1x30x16            ASVG Gather          em | RESULT float32  3:1x30x16            ASVG Gather          em
032 = | RESULT float32  3:1x30x16            HATW Gather          em | RESULT float32  3:1x30x16            HATW Gather          em
033 ~ | RESULT float32  3:1x30x16            ISNB Add             ad | RESULT float32  3:1x30x16            CYAA SkipLayerNormal _o
034 ~ | RESULT float32  3:1x30x16            CYAA LayerNormalizat _o | RESULT float32  3:1x30x1             ACAA SkipLayerNormal un
035 ~ | RESULT float32  3:1x30x16            WAXA MatMul          li | RESULT float32  3:1x30x1             GFGE SkipLayerNormal un
036 ~ | RESULT float32  3:1x30x16            FGFA MatMul          li | RESULT float32  3:1x30x16            ISNB SkipLayerNormal ad
037 ~ | RESULT float32  3:1x30x16            XTUU MatMul          li | RESULT float32  3:1x30x16            WAXA MatMul          li
038 ~ | RESULT float32  3:1x16x30            NFYC Transpose       tr | RESULT float32  3:1x30x16            FGFA MatMul          li
039 ~ | RESULT float32  3:1x30x30            EAHJ MatMul          ma | RESULT float32  3:1x30x30            BUIC FusedMatMul     _o
040 ~ | RESULT float32  3:1x30x30            BUIC Mul             _o | RESULT float32  3:1x30x16            XTUU MatMul          li
041 - | RESULT float32  2:30x30              KGSP Slice           sl |
042 = | RESULT bool     2:30x30              HLZC Equal           eq | RESULT bool     2:30x30              HLZC Equal           eq
043 = | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x30            ???? Where           ma
044 = | RESULT float32  3:1x30x30            IHHH Softmax         so | RESULT float32  3:1x30x30            IHHH Softmax         so
045 = | RESULT float32  3:1x30x16            TVUU MatMul          ma | RESULT float32  3:1x30x16            TVUU MatMul          ma
046 = | RESULT float32  3:1x30x16            NBDY MatMul          li | RESULT float32  3:1x30x16            NBDY MatMul          li
047 = | RESULT float32  3:1x30x16            VASX MatMul          li | RESULT float32  3:1x30x16            VASX MatMul          li
048 ~ | RESULT float32  3:1x30x16            IZSY MatMul          li | RESULT float32  3:1x30x30            RCYW FusedMatMul     _o
049 ~ | RESULT float32  3:1x16x30            AOUZ Transpose       tr | RESULT float32  3:1x30x16            IZSY MatMul          li
050 ~ | RESULT float32  3:1x30x30            QLQJ MatMul          ma | RESULT bool     2:30x30              HLZC Equal           eq
051 ~ | RESULT float32  3:1x30x30            RCYW Mul             _o | RESULT float32  3:1x30x30            ???? Where           ma
052 ~ | RESULT float32  2:30x30              KGSP Slice           sl | RESULT float32  3:1x30x30            IGHH Softmax         so
053 - | RESULT bool     2:30x30              HLZC Equal           eq |
054 ~ | RESULT float32  3:1x30x30            ???? Where           ma | RESULT float32  3:1x30x16            QCGG MatMul          ma
055 ~ | RESULT float32  3:1x30x30            IGHH Softmax         so | RESULT float32  3:1x30x32            IYZC Concat          ca
056 ~ | RESULT float32  3:1x30x16            QCGG MatMul          ma | RESULT float32  3:1x30x16            AAAC MatMul          _o
057 ~ | RESULT float32  3:1x30x32            IYZC Concat          ca | RESULT float32  3:1x30x16            ZYAA Add             li
058 ~ | RESULT float32  3:1x30x16            AAAC MatMul          _o | RESULT float32  3:1x30x16            CYAA SkipLayerNormal _o
059 ~ | RESULT float32  3:1x30x16            ZYAA Add             li | RESULT float32  3:1x30x1             ABAA SkipLayerNormal un
060 ~ | RESULT float32  3:1x30x16            GQNC Add             ad | RESULT float32  3:1x30x1             GFGE SkipLayerNormal un
061 ~ | RESULT float32  3:1x30x16            CYAA LayerNormalizat _o | RESULT float32  3:1x30x16            GQNC SkipLayerNormal ad
062 = | RESULT float32  3:1x30x128           LJDZ MatMul          _o | RESULT float32  3:1x30x128           LJDZ MatMul          _o
063 = | RESULT float32  3:1x30x128           JHBX Add             li | RESULT float32  3:1x30x128           JHBX Add             li
064 = | RESULT float32  3:1x30x128           NHIT Relu            re | RESULT float32  3:1x30x128           NHIT Relu            re
065 = | RESULT float32  3:1x30x16            CDCA MatMul          _o | RESULT float32  3:1x30x16            CDCA MatMul          _o
066 = | RESULT float32  3:1x30x16            DFEC Add             li | RESULT float32  3:1x30x16            DFEC Add             li
067 = | RESULT float32  3:1x30x16            KVRE Add             ou | RESULT float32  3:1x30x16            KVRE Add             ou
068 = | OUTPUT float32  3:1x30x16            KVRE                 ou | OUTPUT float32  3:1x30x16            KVRE                 ou

The conversion should handle dynamic shapes as well, since the input sequence can be of any length. But that’s a topic for another example.

Total running time of the script: (0 minutes 2.474 seconds)

Related examples

to_onnx and padding one dimension to a multiple of a constant

to_onnx and padding one dimension to a multiple of a constant

torch.onnx.export and padding one dimension to a multiple of a constant

torch.onnx.export and padding one dimension to a multiple of a constant

Check the exporter on a dummy from HuggingFace

Check the exporter on a dummy from HuggingFace

to_onnx and Phi-2

to_onnx and Phi-2

Export Phi-3.5-mini-instruct piece by piece

Export Phi-3.5-mini-instruct piece by piece

Gallery generated by Sphinx-Gallery