Phi3
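This example exports a small, randomly initialized Phi3Model (two layers, two attention heads, hidden size 32) to ONNX with to_onnx, then prints the resulting graph with pretty_onnx.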
<<<
import numpy as np
import torch
from transformers import Phi3Config, Phi3Model
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx
def ids_tensor(shape, vocab_size):
    "Creates a random tensor of token ids with the given shape."
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


config = Phi3Config(
    hidden_size=32,
    num_hidden_layers=2,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=512,
    num_attention_heads=2,
    num_key_value_heads=2,
    pad_token_id=1023,
)
config._attn_implementation = "eager"

with torch.no_grad():
    model = Phi3Model(config)

    batch, seq, vocab_size = 2, 1024, 1024
    input_ids = ids_tensor([batch, seq], vocab_size)
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # one forward call to make sure the model runs with these inputs
    model(input_ids, input_mask)

    # export to ONNX
    onx = to_onnx(model, (input_ids, input_mask))
    print(pretty_onnx(onx))
>>>
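The exported graph follows: inputs, initializers, the nodes of the main graph, then two local functions, submod_5 and submod_6, which hold the cos/sin computation of the rotary position embeddings. The Pow/ReduceMean/Add/Sqrt/Reciprocal chains are the RMSNorm layers (eps=1e-05 appears as initializer init1_s_2), and the Sigmoid/Mul pairs after each Split implement the SiLU activation of the gated MLP.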
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[2, 1024]
input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
init: name='b_layers_0_self_attn_rotary_emb_inv_freq' type=dtype('float32') shape=(8,)
init: name='b_layers_1_self_attn_rotary_emb_inv_freq' type=dtype('float32') shape=(8,)
init: name='init7_s_0' type=dtype('int64') shape=() -- array([0])
init: name='init7_s_1024' type=dtype('int64') shape=() -- array([1024])
init: name='init7_s_1' type=dtype('int64') shape=() -- array([1])
init: name='init7_s2_1024_1024' type=dtype('int64') shape=(2,) -- array([1024, 1024])
init: name='init7_s2_-1_1' type=dtype('int64') shape=(2,) -- array([-1, 1])
init: name='init7_s1_1' type=dtype('int64') shape=(1,) -- array([1])
init: name='init7_s4_2_1_1024_1024' type=dtype('int64') shape=(4,) -- array([ 2, 1, 1024, 1024])
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([-3.403e+38], dtype=float32)
init: name='init1_s1_2' type=dtype('float32') shape=(1,) -- array([2.], dtype=float32)
init: name='init7_s1_-1' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init1_s_2' type=dtype('float32') shape=() -- array([1.e-05], dtype=float32)
init: name='init7_s4_2_1024_2_16' type=dtype('int64') shape=(4,) -- array([ 2, 1024, 2, 16])
init: name='init1_s_3' type=dtype('float32') shape=() -- array([4.], dtype=float32)
init: name='init7_s3_2_1024_32' type=dtype('int64') shape=(3,) -- array([ 2, 1024, 32])
init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_1_2' type=dtype('int64') shape=(2,) -- array([1, 2])
init: name='init7_s2_0_2' type=dtype('int64') shape=(2,) -- array([0, 2])
init: name='init7_s2_0_12' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s3_32_32_32' type=dtype('int64') shape=(3,) -- array([32, 32, 32])
init: name='init7_s2_8_8' type=dtype('int64') shape=(2,) -- array([8, 8])
init: name='layers.0.input_layernorm.weight' type=dtype('float32') shape=(32,)
init: name='layers.0.post_attention_layernorm.weight' type=dtype('float32') shape=(32,)
init: name='layers.1.input_layernorm.weight' type=dtype('float32') shape=(32,)
init: name='layers.1.post_attention_layernorm.weight' type=dtype('float32') shape=(32,)
init: name='norm.weight' type=dtype('float32') shape=(32,)
init: name='embed_tokens.weight' type=dtype('float32') shape=(1024, 32)
init: name='layers.0.self_attn.qkv_proj.weight' type=dtype('float32') shape=(96, 32)
init: name='layers.0.self_attn.o_proj.weight' type=dtype('float32') shape=(32, 32)
init: name='layers.0.mlp.gate_up_proj.weight' type=dtype('float32') shape=(32, 32)
init: name='layers.0.mlp.down_proj.weight' type=dtype('float32') shape=(32, 16)
init: name='layers.1.self_attn.qkv_proj.weight' type=dtype('float32') shape=(96, 32)
init: name='layers.1.self_attn.o_proj.weight' type=dtype('float32') shape=(32, 32)
init: name='layers.1.mlp.gate_up_proj.weight' type=dtype('float32') shape=(32, 32)
init: name='layers.1.mlp.down_proj.weight' type=dtype('float32') shape=(32, 16)
ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
Gather(embed_tokens.weight, input_ids) -> embedding
Pow(embedding, init1_s1_2) -> pow_1
ReduceMean(pow_1, init7_s1_-1, keepdims=1) -> mean
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
Unsqueeze(arange, init7_s2_0_12) -> unsqueeze_9
Cast(unsqueeze_9, to=1) -> _to_copy_4
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
Reshape(arange, init7_s2_-1_1) -> view
Greater(arange_1, view) -> gt
Cast(gt, to=1) -> _onx_cast0
Mul(full, _onx_cast0) -> _onx_mul0
Unsqueeze(_onx_mul0, init7_s2_0_1) -> unsqueeze_4
Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
Add(expand_1, unsqueeze_6) -> add
Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
Equal(add, _onx_reshape0) -> eq
Where(eq, init1_s1_, expand_1) -> _onx_where0
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape02
Add(mean, _onx_reshape02) -> add_1
Sqrt(add_1) -> _onx_sqrt0
Reciprocal(_onx_sqrt0) -> rsqrt
Mul(embedding, rsqrt) -> mul_1
Mul(layers.0.input_layernorm.weight, mul_1) -> mul_2
Transpose(layers.0.self_attn.qkv_proj.weight, perm=[1,0]) -> _onx_transpose0
MatMul(mul_2, _onx_transpose0) -> linear
Split(linear, init7_s3_32_32_32, axis=2) -> slice_21, slice_22, slice_23
Reshape(slice_21, init7_s4_2_1024_2_16) -> view_1
Transpose(view_1, perm=[0,2,1,3]) -> transpose
Split(transpose, init7_s2_8_8, axis=3) -> slice_27, slice_28
Neg(slice_28) -> neg
Concat(neg, slice_27, axis=-1) -> cat_1
Reshape(slice_22, init7_s4_2_1024_2_16) -> view_2
Transpose(view_2, perm=[0,2,1,3]) -> transpose_1
Split(transpose_1, init7_s2_8_8, axis=3) -> slice_29, slice_30
Neg(slice_30) -> neg_1
Concat(neg_1, slice_29, axis=-1) -> cat_2
Reshape(slice_23, init7_s4_2_1024_2_16) -> view_3
Transpose(view_3, perm=[0,2,1,3]) -> output_2
Unsqueeze(b_layers_0_self_attn_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
submod_5[local_functions](unsqueeze_8, _to_copy_4) -> wrap_with_autocast#0, wrap_with_autocast#1
Unsqueeze(wrap_with_autocast#0, init7_s1_1) -> unsqueeze_10
Mul(transpose, unsqueeze_10) -> mul_3
Unsqueeze(wrap_with_autocast#1, init7_s1_1) -> unsqueeze_11
Mul(cat_1, unsqueeze_11) -> mul_4
Add(mul_3, mul_4) -> add_2
Mul(transpose_1, unsqueeze_10) -> mul_5
Mul(cat_2, unsqueeze_11) -> mul_6
Add(mul_5, mul_6) -> output_1
Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
MatMul(add_2, transpose_4) -> matmul_1
Reshape(init1_s_3, init7_s1_1) -> _onx_reshape03
Div(matmul_1, _onx_reshape03) -> div
Add(div, _onx_where0) -> add_4
Softmax(add_4, axis=-1) -> softmax
MatMul(softmax, output_2) -> matmul_2
Transpose(matmul_2, perm=[0,2,1,3]) -> transpose_5
Reshape(transpose_5, init7_s3_2_1024_32) -> view_4
Transpose(layers.0.self_attn.o_proj.weight, perm=[1,0]) -> _onx_transpose02
MatMul(view_4, _onx_transpose02) -> linear_1
Add(embedding, linear_1) -> add_5
Pow(add_5, init1_s1_2) -> pow_2
ReduceMean(pow_2, init7_s1_-1, keepdims=1) -> mean_1
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape04
Add(mean_1, _onx_reshape04) -> add_6
Sqrt(add_6) -> _onx_sqrt02
Reciprocal(_onx_sqrt02) -> rsqrt_1
Mul(add_5, rsqrt_1) -> mul_7
Mul(layers.0.post_attention_layernorm.weight, mul_7) -> mul_8
Transpose(layers.0.mlp.gate_up_proj.weight, perm=[1,0]) -> _onx_transpose03
MatMul(mul_8, _onx_transpose03) -> linear_2
Split(linear_2, axis=-1, num_outputs=2) -> split#0, split#1
Sigmoid(split#0) -> _onx_sigmoid0
Mul(split#0, _onx_sigmoid0) -> silu
Mul(split#1, silu) -> mul_9
Transpose(layers.0.mlp.down_proj.weight, perm=[1,0]) -> _onx_transpose04
MatMul(mul_9, _onx_transpose04) -> linear_3
Add(add_5, linear_3) -> add_7
Pow(add_7, init1_s1_2) -> pow_3
ReduceMean(pow_3, init7_s1_-1, keepdims=1) -> mean_2
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape05
Add(mean_2, _onx_reshape05) -> add_8
Sqrt(add_8) -> _onx_sqrt03
Reciprocal(_onx_sqrt03) -> rsqrt_2
Mul(add_7, rsqrt_2) -> mul_10
Mul(layers.1.input_layernorm.weight, mul_10) -> mul_11
Transpose(layers.1.self_attn.qkv_proj.weight, perm=[1,0]) -> _onx_transpose05
MatMul(mul_11, _onx_transpose05) -> linear_4
Split(linear_4, init7_s3_32_32_32, axis=2) -> slice_37, slice_38, slice_39
Reshape(slice_37, init7_s4_2_1024_2_16) -> view_5
Transpose(view_5, perm=[0,2,1,3]) -> transpose_6
Split(transpose_6, init7_s2_8_8, axis=3) -> slice_43, slice_44
Neg(slice_44) -> neg_2
Concat(neg_2, slice_43, axis=-1) -> cat_4
Reshape(slice_38, init7_s4_2_1024_2_16) -> view_6
Transpose(view_6, perm=[0,2,1,3]) -> transpose_7
Split(transpose_7, init7_s2_8_8, axis=3) -> slice_45, slice_46
Neg(slice_46) -> neg_3
Concat(neg_3, slice_45, axis=-1) -> cat_5
Reshape(slice_39, init7_s4_2_1024_2_16) -> view_7
Transpose(view_7, perm=[0,2,1,3]) -> output_4
Unsqueeze(b_layers_1_self_attn_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_13
submod_6[local_functions](unsqueeze_13, _to_copy_4) -> wrap_with_autocast_1#0, wrap_with_autocast_1#1
Unsqueeze(wrap_with_autocast_1#0, init7_s1_1) -> unsqueeze_15
Mul(transpose_6, unsqueeze_15) -> mul_12
Unsqueeze(wrap_with_autocast_1#1, init7_s1_1) -> unsqueeze_16
Mul(cat_4, unsqueeze_16) -> mul_13
Add(mul_12, mul_13) -> add_9
Mul(transpose_7, unsqueeze_15) -> mul_14
Mul(cat_5, unsqueeze_16) -> mul_15
Add(mul_14, mul_15) -> output_3
Transpose(output_3, perm=[0,1,3,2]) -> transpose_10
MatMul(add_9, transpose_10) -> matmul_4
Reshape(init1_s_3, init7_s1_1) -> _onx_reshape06
Div(matmul_4, _onx_reshape06) -> div_1
Add(div_1, _onx_where0) -> add_11
Softmax(add_11, axis=-1) -> softmax_1
MatMul(softmax_1, output_4) -> matmul_5
Transpose(matmul_5, perm=[0,2,1,3]) -> transpose_11
Reshape(transpose_11, init7_s3_2_1024_32) -> view_8
Transpose(layers.1.self_attn.o_proj.weight, perm=[1,0]) -> _onx_transpose06
MatMul(view_8, _onx_transpose06) -> linear_5
Add(add_7, linear_5) -> add_12
Pow(add_12, init1_s1_2) -> pow_4
ReduceMean(pow_4, init7_s1_-1, keepdims=1) -> mean_3
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape07
Add(mean_3, _onx_reshape07) -> add_13
Sqrt(add_13) -> _onx_sqrt04
Reciprocal(_onx_sqrt04) -> rsqrt_3
Mul(add_12, rsqrt_3) -> mul_16
Mul(layers.1.post_attention_layernorm.weight, mul_16) -> mul_17
Transpose(layers.1.mlp.gate_up_proj.weight, perm=[1,0]) -> _onx_transpose07
MatMul(mul_17, _onx_transpose07) -> linear_6
Split(linear_6, axis=-1, num_outputs=2) -> split_1#0, split_1#1
Sigmoid(split_1#0) -> _onx_sigmoid02
Mul(split_1#0, _onx_sigmoid02) -> silu_1
Mul(split_1#1, silu_1) -> mul_18
Transpose(layers.1.mlp.down_proj.weight, perm=[1,0]) -> _onx_transpose08
MatMul(mul_18, _onx_transpose08) -> linear_7
Add(add_12, linear_7) -> add_14
Pow(add_14, init1_s1_2) -> pow_5
ReduceMean(pow_5, init7_s1_-1, keepdims=1) -> mean_4
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape08
Add(mean_4, _onx_reshape08) -> add_15
Sqrt(add_15) -> _onx_sqrt05
Reciprocal(_onx_sqrt05) -> rsqrt_4
Mul(add_14, rsqrt_4) -> mul_19
Mul(norm.weight, mul_19) -> output_0
output: name='output_0' type=dtype('float32') shape=[2, 1024, 32]
output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_3' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_4' type=dtype('float32') shape=[2, 2, 1024, 16]
----- function name=submod_5 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'expand_2'
input: '_to_copy_4'
MatMul(expand_2, _to_copy_4) -> matmul
Transpose(matmul, perm=[0,2,1]) -> transpose_3
Concat(transpose_3, transpose_3, axis=-1) -> cat
Cos(cat) -> output_0
Sin(cat) -> output_1
output: name='output_0' type=? shape=?
output: name='output_1' type=? shape=?
----- function name=submod_6 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='local_functions' version=1
input: 'expand_3'
input: '_to_copy_16'
MatMul(expand_3, _to_copy_16) -> matmul_3
Transpose(matmul_3, perm=[0,2,1]) -> transpose_9
Concat(transpose_9, transpose_9, axis=-1) -> cat_3
Cos(cat_3) -> output_0
Sin(cat_3) -> output_1
output: name='output_0' type=? shape=?
output: name='output_1' type=? shape=?
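output_0 has the shape of the last hidden state, [2, 1024, 32]; the four remaining outputs have the (batch, heads, sequence, head dimension) shape of the per-layer key and value tensors. As a quick sanity check, not part of the original script, the exported model can be run with onnxruntime and compared with the eager model. A minimal sketch, assuming onnxruntime is installed, reusing onx, model, input_ids and input_mask from the script above, and assuming output_0 corresponds to last_hidden_state (consistent with its shape):

<<<
import numpy as np
import torch
import onnxruntime

# build an inference session from the serialized ModelProto
sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
feeds = {
    "input_ids": input_ids.numpy(),
    "attention_mask": input_mask.numpy(),
}
# the session returns all five outputs, in the order printed above
got = sess.run(None, feeds)

with torch.no_grad():
    expected = model(input_ids, input_mask)

# assumption: output_0 is the last hidden state; loose tolerance
np.testing.assert_allclose(
    expected.last_hidden_state.numpy(), got[0], atol=1e-4
)
>>>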