Phi

Phi

<<<

import numpy as np
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiModel
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx


def ids_tensor(shape, vocab_size):
    """Return a random tensor of token ids with the requested shape.

    Every id is drawn uniformly from ``0 .. vocab_size - 1`` (both ends
    included) using numpy's global random state.

    :param shape: sequence of integers, the dimensions of the result
    :param vocab_size: number of distinct ids; ``np.random.randint`` rejects
        an empty range, so it is expected to be at least 1
    :return: contiguous ``torch.long`` tensor of shape ``shape``
    """
    # Number of ids to draw; an empty ``shape`` yields 1, i.e. a scalar tensor.
    total_dims = int(np.prod(shape, dtype=np.int64))

    # ``high`` is exclusive in np.random.randint, so passing ``vocab_size``
    # (not ``vocab_size - 1``) is what makes the largest id reachable.
    # The explicit dtype keeps the result 64-bit on every platform.
    values = np.random.randint(0, vocab_size, size=total_dims, dtype=np.int64)

    return torch.from_numpy(values).view(shape).contiguous()


# Hyper-parameters handed to PhiConfig.
phi_settings = dict(
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
)
config = PhiConfig(**phi_settings)
# Select the "eager" attention implementation.
config._attn_implementation = "eager"

# Model creation, the eager call and the export all run with autograd disabled.
with torch.no_grad():
    model = PhiModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    # Random token ids and a lower-triangular float mask, both (batch, seq).
    input_ids = ids_tensor([batch, seq], vocab_size)
    all_ones = torch.ones(batch, seq, dtype=torch.float32)
    input_mask = torch.tril(all_ones)

    # Call the model once on the inputs before converting it.
    model(input_ids, input_mask)

    # Convert the model to ONNX with the same inputs and print the result.
    export_inputs = (input_ids, input_mask)
    onx = to_onnx(model, export_inputs)
    print(pretty_onnx(onx))

>>>

    opset: domain='' version=18
    opset: domain='local_functions' version=1
    doc_string: large_model=False, inline=False, external_threshold=102...
    input: name='input_ids' type=dtype('int64') shape=[2, 1024]
    input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
    init: name='b_rotary_emb_inv_freq' type=dtype('float32') shape=(4,) -- array([1.   , 0.1  , 0.01 , 0.001], dtype=float32)
    init: name='init7_s_0' type=dtype('int64') shape=() -- array([0])
    init: name='init7_s_1024' type=dtype('int64') shape=() -- array([1024])
    init: name='init7_s_1' type=dtype('int64') shape=() -- array([1])
    init: name='init7_s2_1024_1024' type=dtype('int64') shape=(2,) -- array([1024, 1024])
    init: name='init7_s2_-1_1' type=dtype('int64') shape=(2,) -- array([-1,  1])
    init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
    init: name='init7_s4_2_1_1024_1024' type=dtype('int64') shape=(4,) -- array([   2,    1, 1024, 1024])
    init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
    init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
    init: name='init1_s_2' type=dtype('float32') shape=() -- array([1.], dtype=float32)
    init: name='init7_s4_2_1024_2_16' type=dtype('int64') shape=(4,) -- array([   2, 1024,    2,   16])
    init: name='init1_s_3' type=dtype('float32') shape=() -- array([4.], dtype=float32)
    init: name='init7_s3_2_1024_32' type=dtype('int64') shape=(3,) -- array([   2, 1024,   32])
    init: name='init1_s_4' type=dtype('float32') shape=() -- array([0.5], dtype=float32)
    init: name='init1_s1_4' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
    init: name='init1_s_5' type=dtype('float32') shape=() -- array([0.045], dtype=float32)
    init: name='init1_s_6' type=dtype('float32') shape=() -- array([0.798], dtype=float32)
    init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
    init: name='init7_s2_1_2' type=dtype('int64') shape=(2,) -- array([1, 2])
    init: name='init7_s2_0_2' type=dtype('int64') shape=(2,) -- array([0, 2])
    init: name='init1_s32_' type=dtype('float32') shape=(32,)
    init: name='init1_s32_2' type=dtype('float32') shape=(32,)
    init: name='init7_s2_8_8' type=dtype('int64') shape=(2,) -- array([8, 8])
    init: name='init7_s2_4_4' type=dtype('int64') shape=(2,) -- array([4, 4])
    init: name='embed_tokens.weight' type=dtype('float32') shape=(1024, 32)
    init: name='layers.0.input_layernorm.weight' type=dtype('float32') shape=(32,)
    init: name='layers.0.input_layernorm.bias' type=dtype('float32') shape=(32,)
    init: name='layers.0.self_attn.q_proj.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.0.self_attn.q_proj.bias' type=dtype('float32') shape=(32,)
    init: name='layers.0.self_attn.k_proj.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.0.self_attn.k_proj.bias' type=dtype('float32') shape=(32,)
    init: name='layers.0.self_attn.v_proj.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.0.self_attn.v_proj.bias' type=dtype('float32') shape=(32,)
    init: name='layers.0.self_attn.dense.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.0.self_attn.dense.bias' type=dtype('float32') shape=(32,)
    init: name='layers.0.mlp.fc1.weight' type=dtype('float32') shape=(16, 32)
    init: name='layers.0.mlp.fc1.bias' type=dtype('float32') shape=(16,)
    init: name='layers.0.mlp.fc2.weight' type=dtype('float32') shape=(32, 16)
    init: name='layers.0.mlp.fc2.bias' type=dtype('float32') shape=(32,)
    init: name='layers.1.input_layernorm.weight' type=dtype('float32') shape=(32,)
    init: name='layers.1.input_layernorm.bias' type=dtype('float32') shape=(32,)
    init: name='layers.1.self_attn.q_proj.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.1.self_attn.q_proj.bias' type=dtype('float32') shape=(32,)
    init: name='layers.1.self_attn.k_proj.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.1.self_attn.k_proj.bias' type=dtype('float32') shape=(32,)
    init: name='layers.1.self_attn.v_proj.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.1.self_attn.v_proj.bias' type=dtype('float32') shape=(32,)
    init: name='layers.1.self_attn.dense.weight' type=dtype('float32') shape=(32, 32)
    init: name='layers.1.self_attn.dense.bias' type=dtype('float32') shape=(32,)
    init: name='layers.1.mlp.fc1.weight' type=dtype('float32') shape=(16, 32)
    init: name='layers.1.mlp.fc1.bias' type=dtype('float32') shape=(16,)
    init: name='layers.1.mlp.fc2.weight' type=dtype('float32') shape=(32, 16)
    init: name='layers.1.mlp.fc2.bias' type=dtype('float32') shape=(32,)
    init: name='final_layernorm.weight' type=dtype('float32') shape=(32,)
    init: name='final_layernorm.bias' type=dtype('float32') shape=(32,)
    ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
      Trilu(full, init7_s_1, upper=1) -> triu
    Gather(embed_tokens.weight, input_ids) -> embedding
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
      Unsqueeze(arange, init7_s2_0_1) -> unsqueeze_9
        Cast(unsqueeze_9, to=1) -> _to_copy_1
    Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
    Reshape(arange, init7_s2_-1_1) -> view
      Greater(arange_1, view) -> gt
        Cast(gt, to=1) -> _onx_cast0
        Mul(triu, _onx_cast0) -> _onx_mul0
          Unsqueeze(_onx_mul0, init7_s2_0_1) -> unsqueeze_4
            Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
    Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
      Add(expand_1, unsqueeze_6) -> add
    Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
      Equal(add, _onx_reshape0) -> eq
        Where(eq, init1_s1_, expand_1) -> _onx_where0
    Unsqueeze(b_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
      submod_3[local_functions](unsqueeze_8, _to_copy_1) -> wrap_with_autocast#0, wrap_with_autocast#1
        Unsqueeze(wrap_with_autocast#0, init7_s1_1) -> unsqueeze_10
    Unsqueeze(wrap_with_autocast#1, init7_s1_1) -> unsqueeze_11
    Mul(init1_s32_, layers.0.input_layernorm.weight) -> LayerNormalizationScalePattern_init1_s32_
    Mul(layers.0.input_layernorm.weight, init1_s32_2) -> LayerNormalizationScalePattern_init1_s32_2
      Add(LayerNormalizationScalePattern_init1_s32_2, layers.0.input_layernorm.bias) -> LayerNormalizationScalePattern_init1_s32_3
      LayerNormalization(embedding, LayerNormalizationScalePattern_init1_s32_, LayerNormalizationScalePattern_init1_s32_3, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add02
    Transpose(layers.0.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose0
      MatMul(_onx_add02, _onx_transpose0) -> _onx_matmul0
        Add(_onx_matmul0, layers.0.self_attn.q_proj.bias) -> linear
          Reshape(linear, init7_s4_2_1024_2_16) -> view_1
            Transpose(view_1, perm=[0,2,1,3]) -> transpose_1
              Split(transpose_1, init7_s2_8_8, axis=3) -> slice_24, slice_25
                Split(slice_24, init7_s2_4_4, axis=3) -> slice_28, slice_29
                  Neg(slice_29) -> neg
                  Concat(neg, slice_28, axis=-1) -> cat_1
      Mul(cat_1, unsqueeze_11) -> mul_4
    Transpose(layers.0.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose02
      MatMul(_onx_add02, _onx_transpose02) -> _onx_matmul02
        Add(_onx_matmul02, layers.0.self_attn.k_proj.bias) -> linear_1
          Reshape(linear_1, init7_s4_2_1024_2_16) -> view_2
            Transpose(view_2, perm=[0,2,1,3]) -> transpose_2
              Split(transpose_2, init7_s2_8_8, axis=3) -> slice_26, slice_27
                Split(slice_26, init7_s2_4_4, axis=3) -> slice_30, slice_31
                  Neg(slice_31) -> neg_1
                  Concat(neg_1, slice_30, axis=-1) -> cat_2
      Mul(cat_2, unsqueeze_11) -> mul_6
    Transpose(layers.0.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose03
      MatMul(_onx_add02, _onx_transpose03) -> _onx_matmul03
        Add(_onx_matmul03, layers.0.self_attn.v_proj.bias) -> linear_2
          Reshape(linear_2, init7_s4_2_1024_2_16) -> view_3
            Transpose(view_3, perm=[0,2,1,3]) -> output_2
          Mul(slice_24, unsqueeze_10) -> mul_3
        Add(mul_3, mul_4) -> add_1
          Concat(add_1, slice_25, axis=-1) -> cat_3
    Mul(slice_26, unsqueeze_10) -> mul_5
      Add(mul_5, mul_6) -> add_2
        Concat(add_2, slice_27, axis=-1) -> output_1
          Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
            MatMul(cat_3, transpose_4) -> matmul_1
    Reshape(init1_s_3, init7_s1_1) -> _onx_reshape04
      Div(matmul_1, _onx_reshape04) -> div
        Add(div, _onx_where0) -> add_3
          Softmax(add_3, axis=-1) -> softmax
            MatMul(softmax, output_2) -> matmul_2
              Transpose(matmul_2, perm=[0,2,1,3]) -> transpose_5
                Reshape(transpose_5, init7_s3_2_1024_32) -> view_4
    Transpose(layers.0.self_attn.dense.weight, perm=[1,0]) -> _onx_transpose04
      MatMul(view_4, _onx_transpose04) -> _onx_matmul04
        Add(_onx_matmul04, layers.0.self_attn.dense.bias) -> linear_3
    Transpose(layers.0.mlp.fc1.weight, perm=[1,0]) -> _onx_transpose05
      MatMul(_onx_add02, _onx_transpose05) -> _onx_matmul05
        Add(_onx_matmul05, layers.0.mlp.fc1.bias) -> linear_4
          Pow(linear_4, init1_s1_4) -> pow_1
    Reshape(init1_s_4, init7_s1_1) -> _onx_reshape05
      Mul(linear_4, _onx_reshape05) -> _onx_mul05
    Reshape(init1_s_5, init7_s1_1) -> _onx_reshape06
      Mul(pow_1, _onx_reshape06) -> _onx_mul06
        Add(linear_4, _onx_mul06) -> add_4
    Reshape(init1_s_6, init7_s1_1) -> _onx_reshape07
      Mul(add_4, _onx_reshape07) -> _onx_mul07
        Tanh(_onx_mul07) -> tanh
    Reshape(init1_s_2, init7_s1_1) -> _onx_reshape08
      Add(tanh, _onx_reshape08) -> add_5
        Mul(_onx_mul05, add_5) -> mul_10
    Transpose(layers.0.mlp.fc2.weight, perm=[1,0]) -> _onx_transpose06
      MatMul(mul_10, _onx_transpose06) -> _onx_matmul06
        Add(_onx_matmul06, layers.0.mlp.fc2.bias) -> linear_5
          Add(linear_3, linear_5) -> add_6
      Add(add_6, embedding) -> add_7
    Mul(init1_s32_, layers.1.input_layernorm.weight) -> LayerNormalizationScalePattern_init1_s32_4
    Mul(layers.1.input_layernorm.weight, init1_s32_2) -> LayerNormalizationScalePattern_init1_s32_5
      Add(LayerNormalizationScalePattern_init1_s32_5, layers.1.input_layernorm.bias) -> LayerNormalizationScalePattern_init1_s32_6
      LayerNormalization(add_7, LayerNormalizationScalePattern_init1_s32_4, LayerNormalizationScalePattern_init1_s32_6, axis=-1, epsilon=0.00, stash_type=1) -> _onx_add04
    Transpose(layers.1.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose07
      MatMul(_onx_add04, _onx_transpose07) -> _onx_matmul07
        Add(_onx_matmul07, layers.1.self_attn.q_proj.bias) -> linear_6
          Reshape(linear_6, init7_s4_2_1024_2_16) -> view_5
            Transpose(view_5, perm=[0,2,1,3]) -> transpose_6
              Split(transpose_6, init7_s2_8_8, axis=3) -> slice_38, slice_39
                Split(slice_38, init7_s2_4_4, axis=3) -> slice_42, slice_43
                  Neg(slice_43) -> neg_2
                  Concat(neg_2, slice_42, axis=-1) -> cat_5
      Mul(cat_5, unsqueeze_11) -> mul_12
    Transpose(layers.1.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose08
      MatMul(_onx_add04, _onx_transpose08) -> _onx_matmul08
        Add(_onx_matmul08, layers.1.self_attn.k_proj.bias) -> linear_7
          Reshape(linear_7, init7_s4_2_1024_2_16) -> view_6
            Transpose(view_6, perm=[0,2,1,3]) -> transpose_7
              Split(transpose_7, init7_s2_8_8, axis=3) -> slice_40, slice_41
                Split(slice_40, init7_s2_4_4, axis=3) -> slice_44, slice_45
                  Neg(slice_45) -> neg_3
                  Concat(neg_3, slice_44, axis=-1) -> cat_6
      Mul(cat_6, unsqueeze_11) -> mul_14
    Transpose(layers.1.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose09
      MatMul(_onx_add04, _onx_transpose09) -> _onx_matmul09
        Add(_onx_matmul09, layers.1.self_attn.v_proj.bias) -> linear_8
          Reshape(linear_8, init7_s4_2_1024_2_16) -> view_7
            Transpose(view_7, perm=[0,2,1,3]) -> output_4
          Mul(slice_38, unsqueeze_10) -> mul_11
        Add(mul_11, mul_12) -> add_8
          Concat(add_8, slice_39, axis=-1) -> cat_7
    Mul(slice_40, unsqueeze_10) -> mul_13
      Add(mul_13, mul_14) -> add_9
        Concat(add_9, slice_41, axis=-1) -> output_3
          Transpose(output_3, perm=[0,1,3,2]) -> transpose_9
            MatMul(cat_7, transpose_9) -> matmul_3
    Reshape(init1_s_3, init7_s1_1) -> _onx_reshape09
      Div(matmul_3, _onx_reshape09) -> div_1
        Add(div_1, _onx_where0) -> add_10
          Softmax(add_10, axis=-1) -> softmax_1
            MatMul(softmax_1, output_4) -> matmul_4
              Transpose(matmul_4, perm=[0,2,1,3]) -> transpose_10
                Reshape(transpose_10, init7_s3_2_1024_32) -> view_8
    Transpose(layers.1.self_attn.dense.weight, perm=[1,0]) -> _onx_transpose010
      MatMul(view_8, _onx_transpose010) -> _onx_matmul010
        Add(_onx_matmul010, layers.1.self_attn.dense.bias) -> linear_9
    Transpose(layers.1.mlp.fc1.weight, perm=[1,0]) -> _onx_transpose011
      MatMul(_onx_add04, _onx_transpose011) -> _onx_matmul011
        Add(_onx_matmul011, layers.1.mlp.fc1.bias) -> linear_10
          Pow(linear_10, init1_s1_4) -> pow_2
    Reshape(init1_s_4, init7_s1_1) -> _onx_reshape010
      Mul(linear_10, _onx_reshape010) -> _onx_mul09
    Reshape(init1_s_5, init7_s1_1) -> _onx_reshape011
      Mul(pow_2, _onx_reshape011) -> _onx_mul010
        Add(linear_10, _onx_mul010) -> add_11
    Reshape(init1_s_6, init7_s1_1) -> _onx_reshape012
      Mul(add_11, _onx_reshape012) -> _onx_mul011
        Tanh(_onx_mul011) -> tanh_1
    Reshape(init1_s_2, init7_s1_1) -> _onx_reshape013
      Add(tanh_1, _onx_reshape013) -> add_12
        Mul(_onx_mul09, add_12) -> mul_18
    Transpose(layers.1.mlp.fc2.weight, perm=[1,0]) -> _onx_transpose012
      MatMul(mul_18, _onx_transpose012) -> _onx_matmul012
        Add(_onx_matmul012, layers.1.mlp.fc2.bias) -> linear_11
          Add(linear_9, linear_11) -> add_13
        Add(add_13, add_7) -> add_14
    Mul(init1_s32_, final_layernorm.weight) -> LayerNormalizationScalePattern_init1_s32_7
    Mul(final_layernorm.weight, init1_s32_2) -> LayerNormalizationScalePattern_init1_s32_8
      Add(LayerNormalizationScalePattern_init1_s32_8, final_layernorm.bias) -> LayerNormalizationScalePattern_init1_s32_9
      LayerNormalization(add_14, LayerNormalizationScalePattern_init1_s32_7, LayerNormalizationScalePattern_init1_s32_9, axis=-1, epsilon=0.00, stash_type=1) -> output_0
    output: name='output_0' type=dtype('float32') shape=[2, 1024, 32]
    output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_3' type=dtype('float32') shape=[2, 2, 1024, 16]
    output: name='output_4' type=dtype('float32') shape=[2, 2, 1024, 16]
    ----- function name=submod_3 domain=local_functions
    ----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
    opset: domain='' version=18
    input: 'expand_2'
    input: '_to_copy_1'
    MatMul(expand_2, _to_copy_1) -> matmul
      Transpose(matmul, perm=[0,2,1]) -> transpose
        Concat(transpose, transpose, axis=-1) -> cat
          Cos(cat) -> output_0
          Sin(cat) -> output_1
    output: name='output_0' type=? shape=?
    output: name='output_1' type=? shape=?