Phi
<<<
import numpy as np
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiModel
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
from experimental_experiment.torch_interpreter.onnx_export_errors import (
bypass_export_some_errors,
)
def ids_tensor(shape, vocab_size):
    """Return a random contiguous ``torch.long`` tensor of token ids.

    :param shape: sequence of ints giving the output shape (an empty
        sequence yields a 0-d tensor)
    :param vocab_size: number of distinct ids; every value is drawn
        uniformly from ``[0, vocab_size)``
    :return: contiguous ``torch.long`` tensor of shape *shape*
    """
    # Number of elements; np.prod of an empty shape is 1, matching a 0-d tensor.
    total_dims = int(np.prod(shape, dtype=np.int64))
    # np.random.randint excludes its upper bound, so ``vocab_size`` (not
    # ``vocab_size - 1``) is required for the last id to be reachable; the
    # old bound also raised ValueError when vocab_size == 1.
    values = np.random.randint(0, vocab_size, size=total_dims)
    return torch.tensor(values, dtype=torch.long).view(shape).contiguous()
# A deliberately tiny Phi configuration so the exported graph stays short
# enough to print: 2 decoder layers, hidden size 32, 2 attention heads.
config = PhiConfig(
    vocab_size=1024,
    hidden_size=32,
    intermediate_size=16,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=2,
    max_position_embeddings=512,
)
# Select the "eager" attention implementation; the printed graph shows
# attention as explicit MatMul / Softmax nodes.
config._attn_implementation = "eager"
# Build the model, run it once eagerly, then export it to ONNX while the
# transformers patches from ``bypass_export_some_errors`` are active.
with torch.no_grad(), bypass_export_some_errors(
    patch_transformers=True,
    replace_dynamic_cache=True,
) as modificator:
    model = PhiModel(config)
    # nn.Module instances start in training mode; export the inference graph.
    model.eval()

    # NOTE(review): seq (1024) exceeds config.max_position_embeddings (512);
    # the exported rotary embedding is computed from Range(0, 1024), so this
    # runs, but confirm the mismatch is intended.
    batch, seq = 2, 1024
    # Take the vocabulary size from the config instead of repeating the
    # constant, so the ids stay valid if the config changes.
    vocab_size = config.vocab_size
    input_ids = ids_tensor([batch, seq], vocab_size)
    # Lower-triangular (batch, seq) float mask: row i keeps its first i + 1 tokens.
    input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

    # Eager run before exporting; the result is not used.
    model(input_ids, input_mask)

    # ``modificator`` is applied to the example inputs before export
    # (presumably to substitute patched cache types; confirm in
    # experimental_experiment.torch_interpreter.onnx_export_errors).
    onx = to_onnx(
        model,
        modificator((input_ids, input_mask)),
        export_options=ExportOptions(decomposition_table="default"),
    )
    print(pretty_onnx(onx))
>>>
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_ids' type=dtype('int64') shape=[2, 1024]
input: name='attention_mask' type=dtype('float32') shape=[2, 1024]
init: name='b_rotary_emb_inv_freq' type=float32 shape=(4,) -- array([1. , 0.1 , 0.01 , 0.001], dtype=float32)-- DynamoInterpret.placeholder.0
init: name='init7_s_0' type=int64 shape=() -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s_1024' type=int64 shape=() -- array([1024]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s_1' type=int64 shape=() -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s2_1024_1024' type=int64 shape=(2,) -- array([1024, 1024])-- Opset.make_node.1/Shape
init: name='init7_s2_-1_1' type=int64 shape=(2,) -- array([-1, 1]) -- Opset.make_node.1/Shape
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s4_2_1_1024_1024' type=int64 shape=(4,) -- array([ 2, 1, 1024, 1024])-- GraphBuilder.make_shape_from_results.shape
init: name='init1_s_' type=float32 shape=() -- array([0.], dtype=float32)-- shape_type_compute._cast_inputs.0
init: name='init1_s1_' type=float32 shape=(1,) -- array([-3.403e+38], dtype=float32)-- Opset.make_node.1/Small
init: name='init1_s_2' type=float32 shape=() -- array([1.], dtype=float32)-- shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.0##shape_type_compute._cast_inputs.0
init: name='init7_s4_2_1024_-1_16' type=int64 shape=(4,) -- array([ 2, 1024, -1, 16])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_3' type=float32 shape=() -- array([0.25], dtype=float32)-- shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)
init: name='init7_s3_2_1024_-1' type=int64 shape=(3,) -- array([ 2, 1024, -1])-- Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init1_s_4' type=float32 shape=() -- array([0.5], dtype=float32)-- shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)
init: name='init1_s1_4' type=float32 shape=(1,) -- array([3.], dtype=float32)-- Opset.make_node.1/Small##Opset.make_node.1/Small
init: name='init1_s_5' type=float32 shape=() -- array([0.045], dtype=float32)-- shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)
init: name='init1_s_6' type=float32 shape=() -- array([0.798], dtype=float32)-- shape_type_compute._cast_inputs.1(mul_Tensor)##shape_type_compute._cast_inputs.1(mul_Tensor)
init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1]) -- UnsqueezeUnsqueezePattern.apply.new_axis##UnsqueezeUnsqueezePattern.apply.new_axis
init: name='init7_s2_1_2' type=int64 shape=(2,) -- array([1, 2]) -- UnsqueezeUnsqueezePattern.apply.new_axis
init: name='init7_s2_0_2' type=int64 shape=(2,) -- array([0, 2]) -- UnsqueezeUnsqueezePattern.apply.new_axis
init: name='init1_s32_' type=float32 shape=(32,) -- LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale##LayerNormalizationPattern.apply.scale
init: name='init1_s32_2' type=float32 shape=(32,) -- LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias##LayerNormalizationPattern.apply.bias
init: name='init7_s2_-1_32' type=int64 shape=(2,) -- array([-1, 32]) -- MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3##MatMulAddPattern.new_shape.1##MatMulAddPattern.new_shape.3
init: name='init7_s3_2_1024_-12' type=int64 shape=(3,) -- array([ 2, 1024, -1])-- MatMulAddPattern.new_shape.2##MatMulAddPattern.new_shape.2
init: name='init7_s2_8_8' type=int64 shape=(2,) -- array([8, 8]) -- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
init: name='init7_s2_4_4' type=int64 shape=(2,) -- array([4, 4]) -- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
init: name='embed_tokens.weight' type=float32 shape=(1024, 32) -- DynamoInterpret.placeholder.1/P(embed_tokens.weight)
init: name='layers.0.self_attn.q_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.q_proj.weight)
init: name='layers.0.self_attn.k_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.k_proj.weight)
init: name='layers.0.self_attn.v_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.v_proj.weight)
init: name='layers.0.self_attn.dense.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.0.self_attn.dense.weight)
init: name='layers.0.mlp.fc1.weight' type=float32 shape=(16, 32) -- DynamoInterpret.placeholder.1/P(layers.0.mlp.fc1.weight)
init: name='layers.0.mlp.fc2.weight' type=float32 shape=(32, 16) -- DynamoInterpret.placeholder.1/P(layers.0.mlp.fc2.weight)
init: name='layers.1.self_attn.q_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.q_proj.weight)
init: name='layers.1.self_attn.k_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.k_proj.weight)
init: name='layers.1.self_attn.v_proj.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.v_proj.weight)
init: name='layers.1.self_attn.dense.weight' type=float32 shape=(32, 32)-- DynamoInterpret.placeholder.1/P(layers.1.self_attn.dense.weight)
init: name='layers.1.mlp.fc1.weight' type=float32 shape=(16, 32) -- DynamoInterpret.placeholder.1/P(layers.1.mlp.fc1.weight)
init: name='layers.1.mlp.fc2.weight' type=float32 shape=(32, 16) -- DynamoInterpret.placeholder.1/P(layers.1.mlp.fc2.weight)
ConstantOfShape(init7_s2_1024_1024, value=[-3.402823...) -> full
Trilu(full, init7_s_1, upper=1) -> triu
Gather(embed_tokens.weight, input_ids) -> embedding
LayerNormalization(embedding, init1_s32_, init1_s32_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div__onx_sub_clone_100
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange
Unsqueeze(arange, init7_s2_0_1) -> unsqueeze_9
Cast(unsqueeze_9, to=1) -> _to_copy_1
Range(init7_s_0, init7_s_1024, init7_s_1) -> arange_1
Reshape(arange, init7_s2_-1_1) -> view
Greater(arange_1, view) -> gt
Cast(gt, to=1) -> _onx_cast_gt0
Mul(triu, _onx_cast_gt0) -> _onx_mul_triu0
Unsqueeze(_onx_mul_triu0, init7_s2_0_1) -> unsqueeze_4
Expand(unsqueeze_4, init7_s4_2_1_1024_1024) -> expand_1
Unsqueeze(attention_mask, init7_s2_1_2) -> unsqueeze_6
Add(expand_1, unsqueeze_6) -> add
Reshape(init1_s_, init7_s1_1) -> _onx_reshape_init1_s_0
Equal(add, _onx_reshape_init1_s_0) -> eq
Where(eq, init1_s1_, expand_1) -> masked_fill
Unsqueeze(b_rotary_emb_inv_freq, init7_s2_0_2) -> unsqueeze_8
MatMul(unsqueeze_8, _to_copy_1) -> matmul
Transpose(matmul, perm=[0,2,1]) -> transpose
Concat(transpose, transpose, axis=-1) -> cat
Cos(cat) -> cos
Unsqueeze(cos, init7_s1_1) -> unsqueeze_10
Sin(cat) -> sin
Unsqueeze(sin, init7_s1_1) -> unsqueeze_11
Transpose(layers.0.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose_p_layers_0_self_attn_q_proj_weight0
MatMul(_onx_div__onx_sub_clone_100, _onx_transpose_p_layers_0_self_attn_q_proj_weight0) -> _onx_matmul_layer_norm0
Reshape(_onx_matmul_layer_norm0, init7_s4_2_1024_-1_16) -> view_1
Transpose(view_1, perm=[0,2,1,3]) -> transpose_1
Split(transpose_1, init7_s2_8_8, axis=3) -> slice_24, slice_25
Split(slice_24, init7_s2_4_4, axis=3) -> slice_28, slice_29
Neg(slice_29) -> neg
Concat(neg, slice_28, axis=-1) -> cat_1
Mul(cat_1, unsqueeze_11) -> mul_4
Transpose(layers.0.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose_p_layers_0_self_attn_k_proj_weight0
MatMul(_onx_div__onx_sub_clone_100, _onx_transpose_p_layers_0_self_attn_k_proj_weight0) -> _onx_matmul_layer_norm02
Reshape(_onx_matmul_layer_norm02, init7_s4_2_1024_-1_16) -> view_2
Transpose(view_2, perm=[0,2,1,3]) -> transpose_2
Split(transpose_2, init7_s2_8_8, axis=3) -> slice_26, slice_27
Split(slice_26, init7_s2_4_4, axis=3) -> slice_30, slice_31
Neg(slice_31) -> neg_1
Concat(neg_1, slice_30, axis=-1) -> cat_2
Mul(cat_2, unsqueeze_11) -> mul_6
Transpose(layers.0.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose_p_layers_0_self_attn_v_proj_weight0
MatMul(_onx_div__onx_sub_clone_100, _onx_transpose_p_layers_0_self_attn_v_proj_weight0) -> _onx_matmul_layer_norm03
Reshape(_onx_matmul_layer_norm03, init7_s4_2_1024_-1_16) -> view_3
Transpose(view_3, perm=[0,2,1,3]) -> output_3
Mul(slice_24, unsqueeze_10) -> mul_3
Add(mul_3, mul_4) -> add_1
Concat(add_1, slice_25, axis=-1) -> cat_3
Mul(slice_26, unsqueeze_10) -> mul_5
Add(mul_5, mul_6) -> add_2
Concat(add_2, slice_27, axis=-1) -> output_1
Transpose(output_1, perm=[0,1,3,2]) -> transpose_4
MatMul(cat_3, transpose_4) -> matmul_1
Reshape(init1_s_3, init7_s1_1) -> _onx_reshape_init1_s_30
Mul(matmul_1, _onx_reshape_init1_s_30) -> _onx_mul_matmul_10
Add(_onx_mul_matmul_10, masked_fill) -> add_3
Softmax(add_3, axis=-1) -> softmax
MatMul(softmax, output_3) -> matmul_2
Transpose(matmul_2, perm=[0,2,1,3]) -> transpose_5
Reshape(transpose_5, init7_s3_2_1024_-1) -> view_4
Reshape(view_4, init7_s2_-1_32) -> MatMulAddPattern--view_4
Transpose(layers.0.mlp.fc1.weight, perm=[1,0]) -> _onx_transpose_p_layers_0_mlp_fc1_weight0
MatMul(_onx_div__onx_sub_clone_100, _onx_transpose_p_layers_0_mlp_fc1_weight0) -> _onx_matmul_layer_norm04
Pow(_onx_matmul_layer_norm04, init1_s1_4) -> pow_1
Reshape(init1_s_4, init7_s1_1) -> _onx_reshape_init1_s_40
Mul(_onx_matmul_layer_norm04, _onx_reshape_init1_s_40) -> _onx_mul_linear_40
Reshape(init1_s_5, init7_s1_1) -> _onx_reshape_init1_s_50
Mul(pow_1, _onx_reshape_init1_s_50) -> _onx_mul_pow_10
Add(_onx_matmul_layer_norm04, _onx_mul_pow_10) -> add_4
Reshape(init1_s_6, init7_s1_1) -> _onx_reshape_init1_s_60
Mul(add_4, _onx_reshape_init1_s_60) -> _onx_mul_add_40
Tanh(_onx_mul_add_40) -> tanh
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape_init1_s_203
Add(tanh, _onx_reshape_init1_s_203) -> add_5
Mul(_onx_mul_linear_40, add_5) -> mul_11
Transpose(layers.0.mlp.fc2.weight, perm=[1,0]) -> _onx_transpose_p_layers_0_mlp_fc2_weight0
MatMul(mul_11, _onx_transpose_p_layers_0_mlp_fc2_weight0) -> _onx_matmul_mul_110
Reshape(_onx_matmul_mul_110, init7_s2_-1_32) -> MatMulAddPattern--view_42
Gemm(MatMulAddPattern--view_4, layers.0.self_attn.dense.weight, MatMulAddPattern--view_42, transB=1) -> MatMulAddPattern--view_43
Reshape(MatMulAddPattern--view_43, init7_s3_2_1024_-12) -> add_6
Add(add_6, embedding) -> add_7
LayerNormalization(add_7, init1_s32_, init1_s32_2, axis=-1, epsilon=0.00, stash_type=1) -> _onx_div__onx_sub_add_700
Transpose(layers.1.self_attn.q_proj.weight, perm=[1,0]) -> _onx_transpose_p_layers_1_self_attn_q_proj_weight0
MatMul(_onx_div__onx_sub_add_700, _onx_transpose_p_layers_1_self_attn_q_proj_weight0) -> _onx_matmul_layer_norm_10
Reshape(_onx_matmul_layer_norm_10, init7_s4_2_1024_-1_16) -> view_5
Transpose(view_5, perm=[0,2,1,3]) -> transpose_6
Split(transpose_6, init7_s2_8_8, axis=3) -> slice_38, slice_39
Split(slice_38, init7_s2_4_4, axis=3) -> slice_42, slice_43
Neg(slice_43) -> neg_2
Concat(neg_2, slice_42, axis=-1) -> cat_5
Mul(cat_5, unsqueeze_11) -> mul_13
Transpose(layers.1.self_attn.k_proj.weight, perm=[1,0]) -> _onx_transpose_p_layers_1_self_attn_k_proj_weight0
MatMul(_onx_div__onx_sub_add_700, _onx_transpose_p_layers_1_self_attn_k_proj_weight0) -> _onx_matmul_layer_norm_102
Reshape(_onx_matmul_layer_norm_102, init7_s4_2_1024_-1_16) -> view_6
Transpose(view_6, perm=[0,2,1,3]) -> transpose_7
Split(transpose_7, init7_s2_8_8, axis=3) -> slice_40, slice_41
Split(slice_40, init7_s2_4_4, axis=3) -> slice_44, slice_45
Neg(slice_45) -> neg_3
Concat(neg_3, slice_44, axis=-1) -> cat_6
Mul(cat_6, unsqueeze_11) -> mul_15
Transpose(layers.1.self_attn.v_proj.weight, perm=[1,0]) -> _onx_transpose_p_layers_1_self_attn_v_proj_weight0
MatMul(_onx_div__onx_sub_add_700, _onx_transpose_p_layers_1_self_attn_v_proj_weight0) -> _onx_matmul_layer_norm_103
Reshape(_onx_matmul_layer_norm_103, init7_s4_2_1024_-1_16) -> view_7
Transpose(view_7, perm=[0,2,1,3]) -> output_4
Mul(slice_38, unsqueeze_10) -> mul_12
Add(mul_12, mul_13) -> add_8
Concat(add_8, slice_39, axis=-1) -> cat_7
Mul(slice_40, unsqueeze_10) -> mul_14
Add(mul_14, mul_15) -> add_9
Concat(add_9, slice_41, axis=-1) -> output_2
Transpose(output_2, perm=[0,1,3,2]) -> transpose_9
MatMul(cat_7, transpose_9) -> matmul_3
Reshape(init1_s_3, init7_s1_1) -> _onx_reshape_init1_s_302
Mul(matmul_3, _onx_reshape_init1_s_302) -> _onx_mul_matmul_30
Add(_onx_mul_matmul_30, masked_fill) -> add_10
Softmax(add_10, axis=-1) -> softmax_1
MatMul(softmax_1, output_4) -> matmul_4
Transpose(matmul_4, perm=[0,2,1,3]) -> transpose_10
Reshape(transpose_10, init7_s3_2_1024_-1) -> view_8
Reshape(view_8, init7_s2_-1_32) -> MatMulAddPattern--view_8
Transpose(layers.1.mlp.fc1.weight, perm=[1,0]) -> _onx_transpose_p_layers_1_mlp_fc1_weight0
MatMul(_onx_div__onx_sub_add_700, _onx_transpose_p_layers_1_mlp_fc1_weight0) -> _onx_matmul_layer_norm_104
Pow(_onx_matmul_layer_norm_104, init1_s1_4) -> pow_2
Reshape(init1_s_4, init7_s1_1) -> _onx_reshape_init1_s_402
Mul(_onx_matmul_layer_norm_104, _onx_reshape_init1_s_402) -> _onx_mul_linear_100
Reshape(init1_s_5, init7_s1_1) -> _onx_reshape_init1_s_502
Mul(pow_2, _onx_reshape_init1_s_502) -> _onx_mul_pow_20
Add(_onx_matmul_layer_norm_104, _onx_mul_pow_20) -> add_11
Reshape(init1_s_6, init7_s1_1) -> _onx_reshape_init1_s_602
Mul(add_11, _onx_reshape_init1_s_602) -> _onx_mul_add_110
Tanh(_onx_mul_add_110) -> tanh_1
Reshape(init1_s_2, init7_s1_1) -> _onx_reshape_init1_s_204
Add(tanh_1, _onx_reshape_init1_s_204) -> add_12
Mul(_onx_mul_linear_100, add_12) -> mul_20
Transpose(layers.1.mlp.fc2.weight, perm=[1,0]) -> _onx_transpose_p_layers_1_mlp_fc2_weight0
MatMul(mul_20, _onx_transpose_p_layers_1_mlp_fc2_weight0) -> _onx_matmul_mul_200
Reshape(_onx_matmul_mul_200, init7_s2_-1_32) -> MatMulAddPattern--view_82
Gemm(MatMulAddPattern--view_8, layers.1.self_attn.dense.weight, MatMulAddPattern--view_82, transB=1) -> MatMulAddPattern--view_83
Reshape(MatMulAddPattern--view_83, init7_s3_2_1024_-12) -> add_13
Add(add_13, add_7) -> add_14
LayerNormalization(add_14, init1_s32_, init1_s32_2, axis=-1, epsilon=0.00, stash_type=1) -> output_0
output: name='output_0' type=dtype('float32') shape=[2, 1024, 32]
output: name='output_1' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_2' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_3' type=dtype('float32') shape=[2, 2, 1024, 16]
output: name='output_4' type=dtype('float32') shape=[2, 2, 1024, 16]