Phi

This example exports a small, randomly initialized PhiModel from transformers to ONNX with to_onnx.

<<<

import numpy as np
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiModel
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
from onnx_diagnostic.torch_export_patches import torch_export_patches


def ids_tensor(shape, vocab_size):
    "Creates a random tensor of token ids with the given shape."
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


# A deliberately tiny configuration so that the example stays small and fast.
config = PhiConfig(
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
)
# Use the eager attention implementation.
config._attn_implementation = "eager"

# torch_export_patches temporarily patches transformers so that torch.export
# can trace the model; ``modificator`` adapts the inputs for the patched model.
with torch.no_grad(), torch_export_patches(patch_transformers=True) as modificator:

    model = PhiModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # one call with pytorch to check the model and the inputs are consistent
    model(input_ids, input_mask)

    # export to ONNX with the default decomposition table
    onx = to_onnx(
        model,
        modificator((input_ids, input_mask)),
        export_options=ExportOptions(decomposition_table="default"),
    )
    print(pretty_onnx(onx))

>>>

    opset: domain='' version=18
    input: name='input_ids' type=dtype('int64') shape=[2, 1024]
    input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
    init: name='c_lifted_tensor_0' type=float32 shape=() -- array([0.], dtype=float32)-- DynamoInterpret.placeholder.0
    init: name='b_rotary_emb_inv_freq' type=float32 shape=(4,) -- array([1.   , 0.1  , 0.01 , 0.001], dtype=float32)-- DynamoInterpret.placeholder.0
    init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s4_2_1_1024_1024' type=int64 shape=(4,) -- array([   2,    1, 1024, 1024])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s1_' type=float32 shape=(1,) -- array([-3.403e+38], dtype=float32)-- Opset.make_node.1/Small
    init: name='init7_s4_2_1024_-1_16' type=int64 shape=(4,) -- array([   2, 1024,   -1,   16])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s1_3' type=int64 shape=(1,) -- array([3])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s1_1024' type=int64 shape=(1,) -- array([1024])     -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init7_s3_2_1024_-1' type=int64 shape=(3,) -- array([   2, 1024,   -1])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s1_4' type=float32 shape=(1,) -- array([3.], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
    init: name='view' type=int64 shape=(2, 1, 1, 1) -- array([0, 1])      -- GraphBuilder.constant_folding.from/fold(arange_2,init7_s4_-1_1_1_1)##arange_2/##init7_s4_-1_1_1_1/Opset.make_node.1/Shape
    init: name='and_1' type=bool shape=(2, 1, 1024, 1024)                 -- GraphBuilder.constant_folding.from/fold(_onx_and_new_ones::RSh1)##_onx_and_new_ones::RSh1/
    init: name='_onx_add_expand_4::RSh-1' type=int64 shape=(2097152,)     -- GraphBuilder.constant_folding.from/fold(_onx_add_expand_4,init7_s1_-1)##_onx_add_expand_4/##init7_s1_-1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='_to_copy_1' type=float32 shape=(1, 1, 1024)               -- GraphBuilder.constant_folding.from/fold(unsqueeze_3)##unsqueeze_3/
    init: name='p_layers_0_self_attn_q_proj_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_0_self_attn_q_proj_weight)##p_layers_0_self_attn_q_proj_weight/DynamoInterpret.placeholder.1/P(layers.0.self_attn.q_proj.weight)
    init: name='p_layers_0_self_attn_k_proj_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_0_self_attn_k_proj_weight)##p_layers_0_self_attn_k_proj_weight/DynamoInterpret.placeholder.1/P(layers.0.self_attn.k_proj.weight)
    init: name='p_layers_0_self_attn_v_proj_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_0_self_attn_v_proj_weight)##p_layers_0_self_attn_v_proj_weight/DynamoInterpret.placeholder.1/P(layers.0.self_attn.v_proj.weight)
    init: name='init1_s_2::RSh1' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='p_layers_0_self_attn_dense_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_0_self_attn_dense_weight)##p_layers_0_self_attn_dense_weight/DynamoInterpret.placeholder.1/P(layers.0.self_attn.dense.weight)
    init: name='p_layers_0_mlp_fc1_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_layers_0_mlp_fc1_weight)##p_layers_0_mlp_fc1_weight/DynamoInterpret.placeholder.1/P(layers.0.mlp.fc1.weight)
    init: name='init1_s_3::RSh1' type=float32 shape=(1,) -- array([0.5], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_3,init7_s1_1)##init1_s_3/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_4::RSh1' type=float32 shape=(1,) -- array([0.045], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_4,init7_s1_1)##init1_s_4/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_5::RSh1' type=float32 shape=(1,) -- array([0.798], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_5,init7_s1_1)##init1_s_5/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_::RSh13' type=float32 shape=(1,) -- array([1.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='p_layers_0_mlp_fc2_weight::T10' type=float32 shape=(16, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_0_mlp_fc2_weight)##p_layers_0_mlp_fc2_weight/DynamoInterpret.placeholder.1/P(layers.0.mlp.fc2.weight)
    init: name='p_layers_1_self_attn_q_proj_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_1_self_attn_q_proj_weight)##p_layers_1_self_attn_q_proj_weight/DynamoInterpret.placeholder.1/P(layers.1.self_attn.q_proj.weight)
    init: name='p_layers_1_self_attn_k_proj_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_1_self_attn_k_proj_weight)##p_layers_1_self_attn_k_proj_weight/DynamoInterpret.placeholder.1/P(layers.1.self_attn.k_proj.weight)
    init: name='p_layers_1_self_attn_v_proj_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_1_self_attn_v_proj_weight)##p_layers_1_self_attn_v_proj_weight/DynamoInterpret.placeholder.1/P(layers.1.self_attn.v_proj.weight)
    init: name='init1_s_2::RSh12' type=float32 shape=(1,) -- array([0.25], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_2,init7_s1_1)##init1_s_2/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='p_layers_1_self_attn_dense_weight::T10' type=float32 shape=(32, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_1_self_attn_dense_weight)##p_layers_1_self_attn_dense_weight/DynamoInterpret.placeholder.1/P(layers.1.self_attn.dense.weight)
    init: name='p_layers_1_mlp_fc1_weight::T10' type=float32 shape=(32, 16)-- GraphBuilder.constant_folding.from/fold(p_layers_1_mlp_fc1_weight)##p_layers_1_mlp_fc1_weight/DynamoInterpret.placeholder.1/P(layers.1.mlp.fc1.weight)
    init: name='init1_s_3::RSh12' type=float32 shape=(1,) -- array([0.5], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_3,init7_s1_1)##init1_s_3/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_4::RSh12' type=float32 shape=(1,) -- array([0.045], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_4,init7_s1_1)##init1_s_4/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_5::RSh12' type=float32 shape=(1,) -- array([0.798], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_5,init7_s1_1)##init1_s_5/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='init1_s_::RSh14' type=float32 shape=(1,) -- array([1.], dtype=float32)-- GraphBuilder.constant_folding.from/fold(init1_s_,init7_s1_1)##init1_s_/shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0##init7_s1_1/Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
    init: name='p_layers_1_mlp_fc2_weight::T10' type=float32 shape=(16, 32)-- GraphBuilder.constant_folding.from/fold(p_layers_1_mlp_fc2_weight)##p_layers_1_mlp_fc2_weight/DynamoInterpret.placeholder.1/P(layers.1.mlp.fc2.weight)
    init: name='init7_s2_0_2' type=int64 shape=(2,) -- array([0, 2])      -- UnsqueezeUnsqueezePattern.apply.new_axis
    init: name='init1_s32_' type=float32 shape=(32,)                      -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
    init: name='init1_s32_2' type=float32 shape=(32,)                     -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
    init: name='init7_s2_8_8' type=int64 shape=(2,) -- array([8, 8])      -- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
    init: name='init7_s2_4_4' type=int64 shape=(2,) -- array([4, 4])      -- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
    init: name='embed_tokens.weight' type=float32 shape=(1024, 32)        -- DynamoInterpret.placeholder.1/P(embed_tokens.weight)
    Cast(attention_mask, to=9) -> _to_copy
      Reshape(_to_copy, init7_s1_-1) -> _to_copy::RSh-1
        Gather(_to_copy::RSh-1, _onx_add_expand_4::RSh-1) -> _onx_gather__to_copy::RSh-1
    Gather(embed_tokens.weight, input_ids) -> embedding
      LayerNormalization(embedding, init1_s32_, init1_s32_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_clone_1
        MatMul(_onx_div_sub_clone_1, p_layers_0_self_attn_q_proj_weight::T10) -> _onx_matmul_layer_norm
          Reshape(_onx_matmul_layer_norm, init7_s4_2_1024_-1_16) -> view_5
            Transpose(view_5, perm=[0,2,1,3]) -> transpose_1
              Split(transpose_1, init7_s2_8_8, axis=3) -> slice_4, slice_5
                Split(slice_4, init7_s2_4_4, axis=3) -> slice_8, slice_9
                  Neg(slice_9) -> neg
                  Concat(neg, slice_8, axis=-1) -> cat_1
    Expand(view, init7_s4_2_1_1024_1024) -> expand
      Shape(expand) -> expand::Shape:
        Reshape(_onx_gather__to_copy::RSh-1, expand::Shape:) -> index
          And(and_1, index) -> and_2
            Where(and_2, c_lifted_tensor_0, init1_s1_) -> where
              Slice(where, init7_s1_0, init7_s1_1024, init7_s1_3) -> slice_15
    Unsqueeze(b_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_2
      MatMul(unsqueeze_2, _to_copy_1) -> matmul
        Transpose(matmul, perm=[0,2,1]) -> transpose
          Concat(transpose, transpose, axis=-1) -> cat
            Cos(cat) -> cos
              Unsqueeze(cos, init7_s1_1) -> unsqueeze_4
                Mul(slice_4, unsqueeze_4) -> mul_2
            Sin(cat) -> sin
              Unsqueeze(sin, init7_s1_1) -> unsqueeze_5
                Mul(cat_1, unsqueeze_5) -> mul_3
                  Add(mul_2, mul_3) -> add_1
                Concat(add_1, slice_5, axis=-1) -> cat_3
        MatMul(_onx_div_sub_clone_1, p_layers_0_self_attn_k_proj_weight::T10) -> _onx_matmul_layer_norm2
          Reshape(_onx_matmul_layer_norm2, init7_s4_2_1024_-1_16) -> view_6
            Transpose(view_6, perm=[0,2,1,3]) -> transpose_2
              Split(transpose_2, init7_s2_8_8, axis=3) -> slice_6, slice_7
                Split(slice_6, init7_s2_4_4, axis=3) -> slice_10, slice_11
                  Neg(slice_11) -> neg_1
                  Concat(neg_1, slice_10, axis=-1) -> cat_2
                Mul(cat_2, unsqueeze_5) -> mul_5
        MatMul(_onx_div_sub_clone_1, p_layers_0_self_attn_v_proj_weight::T10) -> _onx_matmul_layer_norm3
          Reshape(_onx_matmul_layer_norm3, init7_s4_2_1024_-1_16) -> view_7
            Transpose(view_7, perm=[0,2,1,3]) -> output_3
    Mul(slice_6, unsqueeze_4) -> mul_4
      Add(mul_4, mul_5) -> add_2
        Concat(add_2, slice_7, axis=-1) -> output_1
          Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
            MatMul(cat_3, transpose_4) -> matmul_1
              Mul(matmul_1, init1_s_2::RSh1) -> _onx_mul_matmul_1
                Add(_onx_mul_matmul_1, slice_15) -> add_3
                  Softmax(add_3, axis=-1) -> softmax
              MatMul(softmax, output_3) -> matmul_2
                Transpose(matmul_2, perm=[0,2,1,3]) -> transpose_5
                  Reshape(transpose_5, init7_s3_2_1024_-1) -> view_8
                    MatMul(view_8, p_layers_0_self_attn_dense_weight::T10) -> _onx_matmul_view_8
        MatMul(_onx_div_sub_clone_1, p_layers_0_mlp_fc1_weight::T10) -> _onx_matmul_layer_norm4
          Mul(_onx_matmul_layer_norm4, init1_s_3::RSh1) -> _onx_mul_linear_4
    Pow(_onx_matmul_layer_norm4, init1_s1_4) -> pow_1
      Mul(pow_1, init1_s_4::RSh1) -> _onx_mul_pow_1
        Add(_onx_matmul_layer_norm4, _onx_mul_pow_1) -> add_4
          Mul(add_4, init1_s_5::RSh1) -> _onx_mul_add_4
            Tanh(_onx_mul_add_4) -> tanh
              Add(tanh, init1_s_::RSh13) -> add_5
            Mul(_onx_mul_linear_4, add_5) -> mul_10
              MatMul(mul_10, p_layers_0_mlp_fc2_weight::T10) -> _onx_matmul_mul_10
                Add(_onx_matmul_view_8, _onx_matmul_mul_10) -> add_6
      Add(add_6, embedding) -> add_7
        LayerNormalization(add_7, init1_s32_, init1_s32_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div_sub_add_7
          MatMul(_onx_div_sub_add_7, p_layers_1_self_attn_q_proj_weight::T10) -> _onx_matmul_layer_norm_1
            Reshape(_onx_matmul_layer_norm_1, init7_s4_2_1024_-1_16) -> view_9
              Transpose(view_9, perm=[0,2,1,3]) -> transpose_6
                Split(transpose_6, init7_s2_8_8, axis=3) -> slice_16, slice_17
                  Split(slice_16, init7_s2_4_4, axis=3) -> slice_20, slice_21
                    Neg(slice_21) -> neg_2
                    Concat(neg_2, slice_20, axis=-1) -> cat_5
                Mul(cat_5, unsqueeze_5) -> mul_12
          MatMul(_onx_div_sub_add_7, p_layers_1_self_attn_k_proj_weight::T10) -> _onx_matmul_layer_norm_12
            Reshape(_onx_matmul_layer_norm_12, init7_s4_2_1024_-1_16) -> view_10
              Transpose(view_10, perm=[0,2,1,3]) -> transpose_7
                Split(transpose_7, init7_s2_8_8, axis=3) -> slice_18, slice_19
                  Split(slice_18, init7_s2_4_4, axis=3) -> slice_22, slice_23
                    Neg(slice_23) -> neg_3
                    Concat(neg_3, slice_22, axis=-1) -> cat_6
                Mul(cat_6, unsqueeze_5) -> mul_14
          MatMul(_onx_div_sub_add_7, p_layers_1_self_attn_v_proj_weight::T10) -> _onx_matmul_layer_norm_13
            Reshape(_onx_matmul_layer_norm_13, init7_s4_2_1024_-1_16) -> view_11
              Transpose(view_11, perm=[0,2,1,3]) -> output_4
    Mul(slice_16, unsqueeze_4) -> mul_11
      Add(mul_11, mul_12) -> add_8
        Concat(add_8, slice_17, axis=-1) -> cat_7
    Mul(slice_18, unsqueeze_4) -> mul_13
      Add(mul_13, mul_14) -> add_9
        Concat(add_9, slice_19, axis=-1) -> output_2
          Transpose(output_2, perm=[0,1,3,2]) -> transpose_9
          MatMul(cat_7, transpose_9) -> matmul_3
            Mul(matmul_3, init1_s_2::RSh12) -> _onx_mul_matmul_3
              Add(_onx_mul_matmul_3, slice_15) -> add_10
                Softmax(add_10, axis=-1) -> softmax_1
                MatMul(softmax_1, output_4) -> matmul_4
                  Transpose(matmul_4, perm=[0,2,1,3]) -> transpose_10
                    Reshape(transpose_10, init7_s3_2_1024_-1) -> view_12
                      MatMul(view_12, p_layers_1_self_attn_dense_weight::T10) -> _onx_matmul_view_12
          MatMul(_onx_div_sub_add_7, p_layers_1_mlp_fc1_weight::T10) -> _onx_matmul_layer_norm_14
            Mul(_onx_matmul_layer_norm_14, init1_s_3::RSh12) -> _onx_mul_linear_10
    Pow(_onx_matmul_layer_norm_14, init1_s1_4) -> pow_2
      Mul(pow_2, init1_s_4::RSh12) -> _onx_mul_pow_2
        Add(_onx_matmul_layer_norm_14, _onx_mul_pow_2) -> add_11
          Mul(add_11, init1_s_5::RSh12) -> _onx_mul_add_11
            Tanh(_onx_mul_add_11) -> tanh_1
              Add(tanh_1, init1_s_::RSh14) -> add_12
              Mul(_onx_mul_linear_10, add_12) -> mul_19
                MatMul(mul_19, p_layers_1_mlp_fc2_weight::T10) -> _onx_matmul_mul_19
                  Add(_onx_matmul_view_12, _onx_matmul_mul_19) -> add_13
        Add(add_13, add_7) -> add_14
          LayerNormalization(add_14, init1_s32_, init1_s32_2, axis=-1, epsilon=0.00, stash_type=1) -> output_0
    output: name='output_0' type=dtype('float32') shape=[2, 1024, 32]
    output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_3' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_4' type=dtype('float32') shape=[2, 2, 1024, 16]
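
The exported graph can be checked with onnxruntime. The following is a minimal sketch, assuming onnxruntime is installed and that ``onx`` is the ModelProto printed above; the names ``sess``, ``feeds`` and ``results`` are only illustrative.

import onnxruntime

# create an inference session from the in-memory ModelProto
sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)

# feed the same inputs as the pytorch run above
feeds = {
    "input_ids": input_ids.numpy(),
    "attention_mask": input_mask.numpy(),
}
results = sess.run(None, feeds)

# output_0 is the last hidden state, the remaining outputs hold the
# key/value tensors of the two layers
print([r.shape for r in results])

The shapes should match the outputs listed above: (2, 1024, 32) for output_0 and (2, 2, 1024, 16) for the four cache tensors.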