Phi
<<<
import numpy as np
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiModel
from experimental_experiment.torch_interpreter import to_onnx
def ids_tensor(shape, vocab_size):
    """Return a random ``torch.long`` tensor of token ids shaped like *shape*.

    :param shape: sequence of integer dimensions of the tensor to create
    :param vocab_size: size of the vocabulary the ids are drawn from
    :return: contiguous tensor of dtype ``torch.long`` and shape *shape*

    The ids are sampled with ``np.random.randint(0, vocab_size - 1)``, whose
    upper bound is exclusive: the values lie in ``[0, vocab_size - 2]``.
    """
    # Number of ids to draw; np.prod of an empty shape is 1, which yields
    # a 0-dimensional tensor after the final view.
    count = int(np.prod(shape))
    # A single vectorised draw replaces one Python-level call per element.
    values = np.random.randint(0, vocab_size - 1, size=count)
    return torch.tensor(values, dtype=torch.long).view(shape).contiguous()
# Deliberately tiny Phi configuration so that the exported graph stays
# short enough to be printed.
config = PhiConfig(
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
)
# Select the "eager" attention implementation.
config._attn_implementation = "eager"

# Gradients are not needed to run or export the model.
with torch.no_grad():
    model = PhiModel(config)

    batch, seq, vocab_size = 2, 1024, 1024

    # Random token ids and a lower-triangular float mask of shape (batch, seq).
    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # Run the model once on the example inputs before exporting it.
    model(input_ids, input_mask)

    # Convert the model to ONNX and print a text rendering of the graph.
    onx = to_onnx(model, (input_ids, input_mask))
    print(onnx_simple_text_plot(onx))
>>>
opset: domain='' version=18
input: name='input_ids' type=dtype('int64') shape=[2, 1024]
input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
init: name='p_embed_tokens_weight' type=dtype('float32') shape=(1024, 32)
init: name='p_layers_0_input_layernorm_weight' type=dtype('float32') shape=(32,)
init: name='p_layers_0_input_layernorm_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_0_self_attn_q_proj_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_0_self_attn_q_proj_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_0_self_attn_k_proj_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_0_self_attn_k_proj_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_0_self_attn_v_proj_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_0_self_attn_v_proj_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_0_self_attn_dense_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_0_self_attn_dense_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_0_mlp_fc1_weight' type=dtype('float32') shape=(16, 32)
init: name='p_layers_0_mlp_fc1_bias' type=dtype('float32') shape=(16,)
init: name='p_layers_0_mlp_fc2_weight' type=dtype('float32') shape=(32, 16)
init: name='p_layers_0_mlp_fc2_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_1_input_layernorm_weight' type=dtype('float32') shape=(32,)
init: name='p_layers_1_input_layernorm_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_1_self_attn_q_proj_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_1_self_attn_q_proj_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_1_self_attn_k_proj_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_1_self_attn_k_proj_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_1_self_attn_v_proj_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_1_self_attn_v_proj_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_1_self_attn_dense_weight' type=dtype('float32') shape=(32, 32)
init: name='p_layers_1_self_attn_dense_bias' type=dtype('float32') shape=(32,)
init: name='p_layers_1_mlp_fc1_weight' type=dtype('float32') shape=(16, 32)
init: name='p_layers_1_mlp_fc1_bias' type=dtype('float32') shape=(16,)
init: name='p_layers_1_mlp_fc2_weight' type=dtype('float32') shape=(32, 16)
init: name='p_layers_1_mlp_fc2_bias' type=dtype('float32') shape=(32,)
init: name='p_final_layernorm_weight' type=dtype('float32') shape=(32,)
init: name='p_final_layernorm_bias' type=dtype('float32') shape=(32,)
init: name='b_layers_0_self_attn_rotary_emb_cos_cached' type=dtype('float32') shape=(1024, 8)
init: name='b_layers_0_self_attn_rotary_emb_sin_cached' type=dtype('float32') shape=(1024, 8)
init: name='b_layers_1_self_attn_rotary_emb_cos_cached' type=dtype('float32') shape=(1024, 8)
init: name='b_layers_1_self_attn_rotary_emb_sin_cached' type=dtype('float32') shape=(1024, 8)
init: name='init7_s_0' type=dtype('int64') shape=() -- array([0])
init: name='init7_s_1024' type=dtype('int64') shape=() -- array([1024])
init: name='init7_s_1' type=dtype('int64') shape=() -- array([1])
init: name='init7_s1_0' type=dtype('int64') shape=(1,) -- array([0])
init: name='init7_s2_1024_1024' type=dtype('int64') shape=(2,) -- array([1024, 1024])
init: name='init7_s2_1024_1' type=dtype('int64') shape=(2,) -- array([1024, 1])
init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
init: name='init7_s4_2_1_1024_1024' type=dtype('int64') shape=(4,) -- array([ 2, 1, 1024, 1024])
init: name='init1_s_' type=dtype('float32') shape=() -- array([1.], dtype=float32)
init: name='init1_s1_2' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
init: name='init7_s2_2048_32' type=dtype('int64') shape=(2,) -- array([2048, 32])
init: name='init7_s3_2_1024_32' type=dtype('int64') shape=(3,) -- array([ 2, 1024, 32])
init: name='init7_s4_2_1024_2_16' type=dtype('int64') shape=(4,) -- array([ 2, 1024, 2, 16])
init: name='init7_s1_1024' type=dtype('int64') shape=(1,) -- array([1024])
init: name='init7_s3_2_1024_16' type=dtype('int64') shape=(3,) -- array([ 2, 1024, 16])
init: name='init1_s_3' type=dtype('float32') shape=() -- array([0.5], dtype=float32)
init: name='init1_s1_4' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s_4' type=dtype('float32') shape=() -- array([0.045], dtype=float32)
init: name='init1_s_5' type=dtype('float32') shape=() -- array([0.798], dtype=float32)
init: name='init1_s_6' type=dtype('float32') shape=() -- array([1.], dtype=float32)
init: name='init7_s2_2048_16' type=dtype('int64') shape=(2,) -- array([2048, 16])
init: name='init1_s_8' type=dtype('float32') shape=() -- array([0.5], dtype=float32)
init: name='init1_s1_5' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s_9' type=dtype('float32') shape=() -- array([0.045], dtype=float32)
init: name='init1_s_10' type=dtype('float32') shape=() -- array([0.798], dtype=float32)
init: name='init1_s_11' type=dtype('float32') shape=() -- array([1.], dtype=float32)
init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_1_2' type=dtype('int64') shape=(2,) -- array([1, 2])
init: name='init1_s_12' type=dtype('float32') shape=() -- array([0.25], dtype=float32)
init: name='init1_s_13' type=dtype('float32') shape=() -- array([0.25], dtype=float32)
init: name='init7_s2_8_8' type=dtype('int64') shape=(2,) -- array([8, 8])
init: name='init7_s2_4_4' type=dtype('int64') shape=(2,) -- array([4, 4])
ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
Unsqueeze(arange, init7_s1_0) -> unsqueeze
Gather(p_embed_tokens_weight, input_ids) -> embedding
LayerNormalization(embedding, p_layers_0_input_layernorm_weight, p_layers_0_input_layernorm_bias, axis=-1, epsilon=0.00) -> native_layer_norm#0, native_layer_norm#1, native_layer_norm#2
Reshape(native_layer_norm#0, init7_s2_2048_32) -> view_1
Gemm(view_1, p_layers_0_mlp_fc1_weight, p_layers_0_mlp_fc1_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_4
Reshape(addmm_4, init7_s3_2_1024_16) -> view_19
Mul(view_19, init1_s_3) -> _onx_mul0
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
Add(arange_1, init7_s_1) -> add
Reshape(add, init7_s2_1024_1) -> view
Less(arange_1, view) -> lt
Where(lt, init1_s1_, full) -> _onx_where0
Unsqueeze(_onx_where0, init7_s2_0_1) -> unsqueeze_2
Expand(unsqueeze_2, init7_s4_2_1_1024_1024) -> expand
Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_4
Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
Sub(init1_s_, expand_1) -> rsub
Cast(rsub, to=9) -> _to_copy_2
Where(_to_copy_2, init1_s1_2, rsub) -> _onx_where02
Cast(_onx_where02, to=9) -> _to_copy_4
Where(_to_copy_4, init1_s1_3, expand) -> _onx_where03
Gemm(view_1, p_layers_0_self_attn_v_proj_weight, p_layers_0_self_attn_v_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_2
Reshape(addmm_2, init7_s4_2_1024_2_16) -> view_9
Transpose(view_9, perm=[0,2,1,3]) -> output_2
Gemm(view_1, p_layers_0_self_attn_k_proj_weight, p_layers_0_self_attn_k_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_1
Reshape(addmm_1, init7_s4_2_1024_2_16) -> view_8
Transpose(view_8, perm=[0,2,1,3]) -> transpose_1
Split(transpose_1, init7_s2_8_8, axis=3) -> slice_9, slice_10
Split(slice_9, init7_s2_4_4, axis=3) -> slice_13, slice_14
Neg(slice_14) -> neg_1
Concat(neg_1, slice_13, axis=-1) -> cat_1
Gemm(view_1, p_layers_0_self_attn_q_proj_weight, p_layers_0_self_attn_q_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
Reshape(addmm, init7_s4_2_1024_2_16) -> view_7
Transpose(view_7, perm=[0,2,1,3]) -> transpose
Split(transpose, init7_s2_8_8, axis=3) -> slice_7, slice_8
Split(slice_7, init7_s2_4_4, axis=3) -> slice_11, slice_12
Neg(slice_12) -> neg
Concat(neg, slice_11, axis=-1) -> cat
Slice(b_layers_0_self_attn_rotary_emb_cos_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_5
Gather(slice_5, unsqueeze, axis=0) -> index
Unsqueeze(index, init7_s1_1) -> unsqueeze_5
Mul(slice_7, unsqueeze_5) -> mul
Slice(b_layers_0_self_attn_rotary_emb_sin_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_6
Gather(slice_6, unsqueeze, axis=0) -> index_1
Unsqueeze(index_1, init7_s1_1) -> unsqueeze_6
Mul(cat, unsqueeze_6) -> mul_1
Add(mul, mul_1) -> add_1
Concat(add_1, slice_8, axis=-1) -> cat_2
Mul(slice_9, unsqueeze_5) -> mul_2
Mul(cat_1, unsqueeze_6) -> mul_3
Add(mul_2, mul_3) -> add_2
Concat(add_2, slice_10, axis=-1) -> output_1
Transpose(output_1, perm=[0,1,3,2]) -> transpose_3
MatMul(cat_2, transpose_3) -> view_12
Mul(view_12, init1_s_12) -> div
Add(div, _onx_where03) -> add_3
Softmax(add_3, axis=-1) -> _softmax
MatMul(_softmax, output_2) -> view_14
Transpose(view_14, perm=[0,2,1,3]) -> transpose_4
Reshape(transpose_4, init7_s2_2048_32) -> view_16
Gemm(view_16, p_layers_0_self_attn_dense_weight, p_layers_0_self_attn_dense_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_3
Pow(view_19, init1_s1_4) -> pow_1
Mul(pow_1, init1_s_4) -> _onx_mul02
Add(view_19, _onx_mul02) -> add_4
Mul(add_4, init1_s_5) -> _onx_mul03
Tanh(_onx_mul03) -> tanh
Add(tanh, init1_s_6) -> add_5
Mul(_onx_mul0, add_5) -> mul_7
Reshape(mul_7, init7_s2_2048_16) -> view_20
Gemm(view_20, p_layers_0_mlp_fc2_weight, p_layers_0_mlp_fc2_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_5
Add(addmm_3, addmm_5) -> Reshape2Of3PatternL_add_6
Reshape(Reshape2Of3PatternL_add_6, init7_s3_2_1024_32) -> add_6
Add(add_6, embedding) -> add_7
LayerNormalization(add_7, p_layers_1_input_layernorm_weight, p_layers_1_input_layernorm_bias, axis=-1, epsilon=0.00) -> native_layer_norm_1#0, native_layer_norm_1#1, native_layer_norm_1#2
Reshape(native_layer_norm_1#0, init7_s2_2048_32) -> view_22
Gemm(view_22, p_layers_1_mlp_fc1_weight, p_layers_1_mlp_fc1_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_10
Reshape(addmm_10, init7_s3_2_1024_16) -> view_40
Mul(view_40, init1_s_8) -> _onx_mul04
Gemm(view_22, p_layers_1_self_attn_v_proj_weight, p_layers_1_self_attn_v_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_8
Reshape(addmm_8, init7_s4_2_1024_2_16) -> view_30
Transpose(view_30, perm=[0,2,1,3]) -> output_4
Gemm(view_22, p_layers_1_self_attn_k_proj_weight, p_layers_1_self_attn_k_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_7
Reshape(addmm_7, init7_s4_2_1024_2_16) -> view_29
Transpose(view_29, perm=[0,2,1,3]) -> transpose_6
Split(transpose_6, init7_s2_8_8, axis=3) -> slice_19, slice_20
Split(slice_19, init7_s2_4_4, axis=3) -> slice_23, slice_24
Neg(slice_24) -> neg_3
Concat(neg_3, slice_23, axis=-1) -> cat_5
Gemm(view_22, p_layers_1_self_attn_q_proj_weight, p_layers_1_self_attn_q_proj_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_6
Reshape(addmm_6, init7_s4_2_1024_2_16) -> view_28
Transpose(view_28, perm=[0,2,1,3]) -> transpose_5
Split(transpose_5, init7_s2_8_8, axis=3) -> slice_17, slice_18
Split(slice_17, init7_s2_4_4, axis=3) -> slice_21, slice_22
Neg(slice_22) -> neg_2
Concat(neg_2, slice_21, axis=-1) -> cat_4
Slice(b_layers_1_self_attn_rotary_emb_cos_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_15
Gather(slice_15, unsqueeze, axis=0) -> index_2
Unsqueeze(index_2, init7_s1_1) -> unsqueeze_7
Mul(slice_17, unsqueeze_7) -> mul_8
Slice(b_layers_1_self_attn_rotary_emb_sin_cached, init7_s1_0, init7_s1_1024, init7_s1_0) -> slice_16
Gather(slice_16, unsqueeze, axis=0) -> index_3
Unsqueeze(index_3, init7_s1_1) -> unsqueeze_8
Mul(cat_4, unsqueeze_8) -> mul_9
Add(mul_8, mul_9) -> add_8
Concat(add_8, slice_18, axis=-1) -> cat_6
Mul(slice_19, unsqueeze_7) -> mul_10
Mul(cat_5, unsqueeze_8) -> mul_11
Add(mul_10, mul_11) -> add_9
Concat(add_9, slice_20, axis=-1) -> output_3
Transpose(output_3, perm=[0,1,3,2]) -> transpose_8
MatMul(cat_6, transpose_8) -> view_33
Mul(view_33, init1_s_13) -> div_1
Add(div_1, _onx_where03) -> add_10
Softmax(add_10, axis=-1) -> _softmax_1
MatMul(_softmax_1, output_4) -> view_35
Transpose(view_35, perm=[0,2,1,3]) -> transpose_9
Reshape(transpose_9, init7_s2_2048_32) -> view_37
Gemm(view_37, p_layers_1_self_attn_dense_weight, p_layers_1_self_attn_dense_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_9
Pow(view_40, init1_s1_5) -> pow_2
Mul(pow_2, init1_s_9) -> _onx_mul05
Add(view_40, _onx_mul05) -> add_11
Mul(add_11, init1_s_10) -> _onx_mul06
Tanh(_onx_mul06) -> tanh_1
Add(tanh_1, init1_s_11) -> add_12
Mul(_onx_mul04, add_12) -> mul_15
Reshape(mul_15, init7_s2_2048_16) -> view_41
Gemm(view_41, p_layers_1_mlp_fc2_weight, p_layers_1_mlp_fc2_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm_11
Add(addmm_9, addmm_11) -> Reshape2Of3PatternL_add_13
Reshape(Reshape2Of3PatternL_add_13, init7_s3_2_1024_32) -> add_13
Add(add_13, add_7) -> add_14
LayerNormalization(add_14, p_final_layernorm_weight, p_final_layernorm_bias, axis=-1, epsilon=0.00) -> output_0, native_layer_norm_2#1, native_layer_norm_2#2
output: name='output_0' type=dtype('float32') shape=[2, 1024, 32]
output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_3' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_4' type=dtype('float32') shape=[2, 2, 1024, 16]