Phi

Phi

<<<

import numpy as np
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiModel
from experimental_experiment.torch_interpreter import to_onnx


def ids_tensor(shape, vocab_size):
    """Build a random tensor of token ids.

    :param shape: sequence of ints, the dimensions of the returned tensor
    :param vocab_size: number of distinct ids; every id is drawn from
        ``0 .. vocab_size - 1`` (both bounds included)
    :return: contiguous ``torch.long`` tensor of the requested shape
    """
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    # ``np.random.randint`` excludes its upper bound, so ``vocab_size`` (and
    # not ``vocab_size - 1``) is passed: the last id stays reachable and
    # ``vocab_size == 1`` remains valid. One vectorized draw replaces the
    # element-by-element Python loop.
    values = np.random.randint(0, vocab_size, size=total_dims, dtype=np.int64)

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


# Small Phi configuration: 2 hidden layers, hidden size 32, 2 attention heads.
config = PhiConfig(
    vocab_size=1024,
    hidden_size=32,
    intermediate_size=16,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=2,
    max_position_embeddings=512,
)
# Request the "eager" attention implementation.
config._attn_implementation = "eager"

# Gradients are disabled while the model is built, run and exported.
with torch.no_grad():
    model = PhiModel(config)

    # NOTE(review): seq (1024) is larger than max_position_embeddings (512)
    # set in the configuration above -- confirm this is intended.
    batch = 2
    seq = 1024
    vocab_size = 1024

    # Random token ids of shape (batch, seq).
    input_ids = ids_tensor([batch, seq], vocab_size)
    # Lower-triangular matrix of ones of shape (batch, seq), given to the
    # model as its second positional argument.
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))
    export_args = (input_ids, input_mask)

    # Run the model once on the example inputs before exporting it.
    model(*export_args)

    # Convert to ONNX with the same inputs, then print a text rendering.
    onx = to_onnx(model, export_args)
    rendered = onnx_simple_text_plot(onx)
    print(rendered)

>>>

    opset: domain='' version=18
    input: name='input_ids' type=dtype('int64') shape=[2, 1024]
    input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
    init: name='p_embed_tokens_weight' type=dtype('float32') shape=(1024, 32)
    init: name='p_layers_0_input_layernorm_weight' type=dtype('float32') shape=(32,)
    init: name='p_layers_0_input_layernorm_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_0_self_attn_q_proj_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_0_self_attn_q_proj_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_0_self_attn_k_proj_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_0_self_attn_k_proj_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_0_self_attn_v_proj_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_0_self_attn_v_proj_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_0_self_attn_dense_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_0_self_attn_dense_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_0_mlp_fc1_weight' type=dtype('float32') shape=(16, 32)
    init: name='p_layers_0_mlp_fc1_bias' type=dtype('float32') shape=(16,)
    init: name='p_layers_0_mlp_fc2_weight' type=dtype('float32') shape=(32, 16)
    init: name='p_layers_0_mlp_fc2_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_1_input_layernorm_weight' type=dtype('float32') shape=(32,)
    init: name='p_layers_1_input_layernorm_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_1_self_attn_q_proj_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_1_self_attn_q_proj_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_1_self_attn_k_proj_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_1_self_attn_k_proj_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_1_self_attn_v_proj_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_1_self_attn_v_proj_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_1_self_attn_dense_weight' type=dtype('float32') shape=(32, 32)
    init: name='p_layers_1_self_attn_dense_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_1_mlp_fc1_weight' type=dtype('float32') shape=(16, 32)
    init: name='p_layers_1_mlp_fc1_bias' type=dtype('float32') shape=(16,)
    init: name='p_layers_1_mlp_fc2_weight' type=dtype('float32') shape=(32, 16)
    init: name='p_layers_1_mlp_fc2_bias' type=dtype('float32') shape=(32,)
    init: name='p_final_layernorm_weight' type=dtype('float32') shape=(32,)
    init: name='p_final_layernorm_bias' type=dtype('float32') shape=(32,)
    init: name='b_layers_0_self_attn_rotary_emb_cos_cached' type=dtype('float32') shape=(1024, 8)
    init: name='b_layers_0_self_attn_rotary_emb_sin_cached' type=dtype('float32') shape=(1024, 8)
    init: name='b_layers_1_self_attn_rotary_emb_cos_cached' type=dtype('float32') shape=(1024, 8)
    init: name='b_layers_1_self_attn_rotary_emb_sin_cached' type=dtype('float32') shape=(1024, 8)
    init: name='init7_s_0' type=dtype('int64') shape=() -- array([0])
    init: name='init7_s_1024' type=dtype('int64') shape=() -- array([1024])
    init: name='init7_s_1' type=dtype('int64') shape=() -- array([1])
    init: name='init7_s1_0' type=dtype('int64') shape=(1,) -- array([0])
    init: name='init7_s2_1024_1024' type=dtype('int64') shape=(2,) -- array([1024, 1024])
    init: name='init7_s2_1024_1' type=dtype('int64') shape=(2,) -- array([1024,    1])
    init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
    init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
    init: name='init7_s4_2_1_1024_1024' type=dtype('int64') shape=(4,) -- array([   2,    1, 1024, 1024])
    init: name='init1_s_' type=dtype('float32') shape=() -- array([1.], dtype=float32)
    init: name='init1_s1_2' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
    init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
    init: name='init7_s2_2048_32' type=dtype('int64') shape=(2,) -- array([2048,   32])
    init: name='init7_s3_2_1024_32' type=dtype('int64') shape=(3,) -- array([   2, 1024,   32])
    init: name='init7_s4_2_1024_2_16' type=dtype('int64') shape=(4,) -- array([   2, 1024,    2,   16])
    init: name='init7_s1_1024' type=dtype('int64') shape=(1,) -- array([1024])
    init: name='init7_s3_2_1024_16' type=dtype('int64') shape=(3,) -- array([   2, 1024,   16])
    init: name='init1_s_3' type=dtype('float32') shape=() -- array([0.5], dtype=float32)
    init: name='init1_s1_4' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
    init: name='init1_s_4' type=dtype('float32') shape=() -- array([0.045], dtype=float32)
    init: name='init1_s_5' type=dtype('float32') shape=() -- array([0.798], dtype=float32)
    init: name='init1_s_6' type=dtype('float32') shape=() -- array([1.], dtype=float32)
    init: name='init7_s2_2048_16' type=dtype('int64') shape=(2,) -- array([2048,   16])
    init: name='init1_s_8' type=dtype('float32') shape=() -- array([0.5], dtype=float32)
    init: name='init1_s1_5' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
    init: name='init1_s_9' type=dtype('float32') shape=() -- array([0.045], dtype=float32)
    init: name='init1_s_10' type=dtype('float32') shape=() -- array([0.798], dtype=float32)
    init: name='init1_s_11' type=dtype('float32') shape=() -- array([1.], dtype=float32)
    init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
    init: name='init7_s2_1_2' type=dtype('int64') shape=(2,) -- array([1, 2])
    init: name='init1_s_12' type=dtype('float32') shape=() -- array([0.25], dtype=float32)
    init: name='init1_s_13' type=dtype('float32') shape=() -- array([0.25], dtype=float32)
    init: name='init7_s2_8_8' type=dtype('int64') shape=(2,) -- array([8, 8])
    init: name='init7_s2_4_4' type=dtype('int64') shape=(2,) -- array([4, 4])
    ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
      Unsqueeze(arange, init7_s1_0) -> unsqueeze
    Gather(p_embed_tokens_weight, input_ids) -> embedding
      LayerNormalization(embedding, p_layers_0_input_layernorm_weight, p_layers_0_input_layernorm_bias, axis=-1, epsilon=0.00) -> native_layer_norm#0, native_layer_norm#1, native_layer_norm#2
        Reshape(native_layer_norm#0, init7_s2_2048_32) -> view_1
          Gemm(view_1, p_layers_0_mlp_fc1_weight, p_layers_0_mlp_fc1_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_4
            Reshape(addmm_4, init7_s3_2_1024_16) -> view_19
              Mul(view_19, init1_s_3) -> _onx_mul0
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
      Add(arange_1, init7_s_1) -> add
        Reshape(add, init7_s2_1024_1) -> view
      Less(arange_1, view) -> lt
      Where(lt, init1_s1_, full) -> _onx_where0
        Unsqueeze(_onx_where0, init7_s2_0_1) -> unsqueeze_2
          Expand(unsqueeze_2, init7_s4_2_1_1024_1024) -> expand
    Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_4
      Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
        Sub(init1_s_, expand_1) -> rsub
          Cast(rsub, to=9) -> _to_copy_2
          Where(_to_copy_2, init1_s1_2, rsub) -> _onx_where02
            Cast(_onx_where02, to=9) -> _to_copy_4
            Where(_to_copy_4, init1_s1_3, expand) -> _onx_where03
          Gemm(view_1, p_layers_0_self_attn_v_proj_weight, p_layers_0_self_attn_v_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_2
            Reshape(addmm_2, init7_s4_2_1024_2_16) -> view_9
              Transpose(view_9, perm=[0,2,1,3]) -> output_2
          Gemm(view_1, p_layers_0_self_attn_k_proj_weight, p_layers_0_self_attn_k_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_1
            Reshape(addmm_1, init7_s4_2_1024_2_16) -> view_8
              Transpose(view_8, perm=[0,2,1,3]) -> transpose_1
                Split(transpose_1, init7_s2_8_8, axis=3) -> slice_9, slice_10
                  Split(slice_9, init7_s2_4_4, axis=3) -> slice_13, slice_14
                    Neg(slice_14) -> neg_1
                    Concat(neg_1, slice_13, axis=-1) -> cat_1
          Gemm(view_1, p_layers_0_self_attn_q_proj_weight, p_layers_0_self_attn_q_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
            Reshape(addmm, init7_s4_2_1024_2_16) -> view_7
              Transpose(view_7, perm=[0,2,1,3]) -> transpose
                Split(transpose, init7_s2_8_8, axis=3) -> slice_7, slice_8
                  Split(slice_7, init7_s2_4_4, axis=3) -> slice_11, slice_12
                    Neg(slice_12) -> neg
                    Concat(neg, slice_11, axis=-1) -> cat
    Slice(b_layers_0_self_attn_rotary_emb_cos_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_5
      Gather(slice_5, unsqueeze, axis=0) -> index
        Unsqueeze(index, init7_s1_1) -> unsqueeze_5
          Mul(slice_7, unsqueeze_5) -> mul
    Slice(b_layers_0_self_attn_rotary_emb_sin_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_6
      Gather(slice_6, unsqueeze, axis=0) -> index_1
        Unsqueeze(index_1, init7_s1_1) -> unsqueeze_6
          Mul(cat, unsqueeze_6) -> mul_1
            Add(mul, mul_1) -> add_1
              Concat(add_1, slice_8, axis=-1) -> cat_2
          Mul(slice_9, unsqueeze_5) -> mul_2
    Mul(cat_1, unsqueeze_6) -> mul_3
      Add(mul_2, mul_3) -> add_2
        Concat(add_2, slice_10, axis=-1) -> output_1
          Transpose(output_1, perm=[0,1,3,2]) -> transpose_3
            MatMul(cat_2, transpose_3) -> view_12
              Mul(view_12, init1_s_12) -> div
              Add(div, _onx_where03) -> add_3
                Softmax(add_3, axis=-1) -> _softmax
                MatMul(_softmax, output_2) -> view_14
                  Transpose(view_14, perm=[0,2,1,3]) -> transpose_4
                    Reshape(transpose_4, init7_s2_2048_32) -> view_16
                      Gemm(view_16, p_layers_0_self_attn_dense_weight, p_layers_0_self_attn_dense_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_3
              Pow(view_19, init1_s1_4) -> pow_1
                Mul(pow_1, init1_s_4) -> _onx_mul02
              Add(view_19, _onx_mul02) -> add_4
                Mul(add_4, init1_s_5) -> _onx_mul03
                  Tanh(_onx_mul03) -> tanh
                    Add(tanh, init1_s_6) -> add_5
                Mul(_onx_mul0, add_5) -> mul_7
                  Reshape(mul_7, init7_s2_2048_16) -> view_20
                    Gemm(view_20, p_layers_0_mlp_fc2_weight, p_layers_0_mlp_fc2_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_5
                      Add(addmm_3, addmm_5) -> Reshape2Of3PatternL_add_6
                        Reshape(Reshape2Of3PatternL_add_6, init7_s3_2_1024_32) -> add_6
      Add(add_6, embedding) -> add_7
        LayerNormalization(add_7, p_layers_1_input_layernorm_weight, p_layers_1_input_layernorm_bias, axis=-1, epsilon=0.00) -> native_layer_norm_1#0, native_layer_norm_1#1, native_layer_norm_1#2
          Reshape(native_layer_norm_1#0, init7_s2_2048_32) -> view_22
            Gemm(view_22, p_layers_1_mlp_fc1_weight, p_layers_1_mlp_fc1_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_10
              Reshape(addmm_10, init7_s3_2_1024_16) -> view_40
                Mul(view_40, init1_s_8) -> _onx_mul04
            Gemm(view_22, p_layers_1_self_attn_v_proj_weight, p_layers_1_self_attn_v_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_8
              Reshape(addmm_8, init7_s4_2_1024_2_16) -> view_30
                Transpose(view_30, perm=[0,2,1,3]) -> output_4
            Gemm(view_22, p_layers_1_self_attn_k_proj_weight, p_layers_1_self_attn_k_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_7
              Reshape(addmm_7, init7_s4_2_1024_2_16) -> view_29
                Transpose(view_29, perm=[0,2,1,3]) -> transpose_6
                  Split(transpose_6, init7_s2_8_8, axis=3) -> slice_19, slice_20
                    Split(slice_19, init7_s2_4_4, axis=3) -> slice_23, slice_24
                      Neg(slice_24) -> neg_3
                      Concat(neg_3, slice_23, axis=-1) -> cat_5
            Gemm(view_22, p_layers_1_self_attn_q_proj_weight, p_layers_1_self_attn_q_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_6
              Reshape(addmm_6, init7_s4_2_1024_2_16) -> view_28
                Transpose(view_28, perm=[0,2,1,3]) -> transpose_5
                  Split(transpose_5, init7_s2_8_8, axis=3) -> slice_17, slice_18
                    Split(slice_17, init7_s2_4_4, axis=3) -> slice_21, slice_22
                      Neg(slice_22) -> neg_2
                      Concat(neg_2, slice_21, axis=-1) -> cat_4
    Slice(b_layers_1_self_attn_rotary_emb_cos_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_15
      Gather(slice_15, unsqueeze, axis=0) -> index_2
        Unsqueeze(index_2, init7_s1_1) -> unsqueeze_7
          Mul(slice_17, unsqueeze_7) -> mul_8
    Slice(b_layers_1_self_attn_rotary_emb_sin_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_16
      Gather(slice_16, unsqueeze, axis=0) -> index_3
        Unsqueeze(index_3, init7_s1_1) -> unsqueeze_8
          Mul(cat_4, unsqueeze_8) -> mul_9
            Add(mul_8, mul_9) -> add_8
              Concat(add_8, slice_18, axis=-1) -> cat_6
          Mul(slice_19, unsqueeze_7) -> mul_10
    Mul(cat_5, unsqueeze_8) -> mul_11
      Add(mul_10, mul_11) -> add_9
        Concat(add_9, slice_20, axis=-1) -> output_3
          Transpose(output_3, perm=[0,1,3,2]) -> transpose_8
            MatMul(cat_6, transpose_8) -> view_33
              Mul(view_33, init1_s_13) -> div_1
              Add(div_1, _onx_where03) -> add_10
                Softmax(add_10, axis=-1) -> _softmax_1
                  MatMul(_softmax_1, output_4) -> view_35
                    Transpose(view_35, perm=[0,2,1,3]) -> transpose_9
                      Reshape(transpose_9, init7_s2_2048_32) -> view_37
                        Gemm(view_37, p_layers_1_self_attn_dense_weight, p_layers_1_self_attn_dense_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_9
                Pow(view_40, init1_s1_5) -> pow_2
                  Mul(pow_2, init1_s_9) -> _onx_mul05
                Add(view_40, _onx_mul05) -> add_11
                  Mul(add_11, init1_s_10) -> _onx_mul06
                    Tanh(_onx_mul06) -> tanh_1
                      Add(tanh_1, init1_s_11) -> add_12
                  Mul(_onx_mul04, add_12) -> mul_15
                    Reshape(mul_15, init7_s2_2048_16) -> view_41
                      Gemm(view_41, p_layers_1_mlp_fc2_weight, p_layers_1_mlp_fc2_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_11
                        Add(addmm_9, addmm_11) -> Reshape2Of3PatternL_add_13
                          Reshape(Reshape2Of3PatternL_add_13, init7_s3_2_1024_32) -> add_13
        Add(add_13, add_7) -> add_14
          LayerNormalization(add_14, p_final_layernorm_weight, p_final_layernorm_bias, axis=-1, epsilon=0.00) -> output_0, native_layer_norm_2#1, native_layer_norm_2#2
    output: name='output_0' type=dtype('float32') shape=[2, 1024, 32]
    output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_3' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_4' type=dtype('float32') shape=[2, 2, 1024, 16]