Note
Go to the end to download the full example code.
torch.onnx.export and Phi-2
Exports the Phi-2 model. We use a dummy (untrained) model. The main difficulty is setting the dynamic shapes properly.
Model
import copy
from typing import Any, Dict
import onnx
import torch
import transformers
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import string_type, pretty_onnx
def get_phi2_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
"""
Gets a non initialized model with its inputs
:param batch_size: batch size
:param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
:return: dictionary
See `Phi-2/config.json
<https://huggingface.co/microsoft/phi-2/blob/main/config.json>`_.
"""
config = {
"_name_or_path": "microsoft/phi-2",
"architectures": ["PhiForCausalLM"],
"attention_dropout": 0.0,
"bos_token_id": 50256,
"embd_pdrop": 0.0,
"eos_token_id": 50256,
"hidden_act": "gelu_new",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 10240,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 2048,
"model_type": "phi",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"partial_rotary_factor": 0.4,
"qk_layernorm": False,
"resid_pdrop": 0.1,
"rope_scaling": None,
"rope_theta": 10000.0,
"tie_word_embeddings": False,
"torch_dtype": "float16",
"transformers_version": "4.37.0",
"use_cache": True,
"vocab_size": 51200,
}
config.update(**kwargs)
conf = transformers.PhiConfig(**config)
model = transformers.PhiForCausalLM(conf)
model.eval()
batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
shapes = {}
cache = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
for i in range(config["num_hidden_layers"]):
cache.update(
torch.randn(batch_size, 32, 30, 80), torch.randn(batch_size, 32, 30, 80), i
)
cache2 = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
for i in range(config["num_hidden_layers"]):
cache2.update(
torch.randn(batch_size + 1, 32, 31, 80),
torch.randn(batch_size + 1, 32, 31, 80),
i,
)
inputs = dict(
input_ids=torch.randint(0, 50285, (batch_size, 3)).to(torch.int64),
attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
past_key_values=cache,
)
inputs2 = dict(
input_ids=torch.randint(0, 50285, (batch_size + 1, 4)).to(torch.int64),
attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
past_key_values=cache2,
)
n = len(cache.key_cache)
cache_length = torch.export.Dim("cache_length", min=1, max=4096)
shapes.update(
{
"input_ids": {0: batch, 1: seq_length},
"attention_mask": {
0: batch,
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
},
"past_key_values": [
[{0: batch, 2: cache_length} for _ in range(n)], # 0: batch,
[{0: batch, 2: cache_length} for _ in range(n)], # 0: batch,
],
}
)
return dict(inputs=inputs, model=model, dynamic_shapes=shapes, inputs2=inputs2)
# Only two layers instead of the 32 of the full configuration.
data = get_phi2_untrained(num_hidden_layers=2)
model, inputs, dynamic_shapes = (
    data[key] for key in ("model", "inputs", "dynamic_shapes")
)
print("inputs", string_type(inputs, with_shape=True))
print("dynamic_shapes", dynamic_shapes)
inputs dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
dynamic_shapes {'input_ids': {0: <class '__main__.batch'>, 1: <class '__main__.seq_length'>}, 'attention_mask': {0: <class '__main__.batch'>, 1: <_DimHint.DYNAMIC: 3>}, 'past_key_values': [[{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}], [{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}]]}
Let’s check that it works. We need to copy the inputs before calling the model because the model modifies them, and they would no longer be properly set up when the export starts.
# Run the model on a deep copy: the call modifies its inputs and the
# original ones are still needed for the export below.
model(**copy.deepcopy(inputs))
CausalLMOutputWithPast(loss=None, logits=tensor([[[ 0.7979, 0.3673, 1.7276, ..., 1.0755, -0.5409, 1.7213],
[ 0.0160, -1.5066, 2.6774, ..., -0.3734, -1.6701, 0.8857],
[-0.1865, -0.7282, 1.2860, ..., 1.8946, 0.0776, 0.3918]],
[[-0.5348, 0.1761, 1.3224, ..., -0.0308, 0.6626, -0.8239],
[ 1.2772, -0.2547, 0.2357, ..., 0.4670, -0.8453, 0.2779],
[-0.7813, -1.1209, 1.7710, ..., -0.2175, -1.8416, 0.2227]]],
grad_fn=<ViewBackward0>), past_key_values=DynamicCache(), hidden_states=None, attentions=None)
Export
Let’s export with torch.onnx.export().
# Without any patch, the export is expected to fail; the error is printed
# instead of raised so that the rest of the example still runs.
try:
    torch.onnx.export(
        copy.deepcopy(model),
        (),
        dynamo=True,
        dynamic_shapes=dynamic_shapes,
        kwargs=copy.deepcopy(inputs),
    )
except Exception as e:
    print(f"export failed due to {e}")
/home/xadupre/github/onnxscript/onnxscript/converter.py:823: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
/home/xadupre/github/onnxscript/onnxscript/converter.py:823: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script...
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:460: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/models/phi/modeling_phi.py:703: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if sequence_length != 1:
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:444: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
len(self.key_cache[layer_idx]) == 0
def forward(self, arg0_1: "f32[51200]", arg1_1: "f32[51200, 2560]", arg2_1: "f32[51200, 2560]", arg3_1: "f32[2560]", arg4_1: "f32[2560]", arg5_1: "f32[2560]", arg6_1: "f32[2560]", arg7_1: "f32[2560]", arg8_1: "f32[2560, 2560]", arg9_1: "f32[2560]", arg10_1: "f32[2560, 2560]", arg11_1: "f32[2560]", arg12_1: "f32[2560, 2560]", arg13_1: "f32[2560]", arg14_1: "f32[2560, 2560]", arg15_1: "f32[10240]", arg16_1: "f32[10240, 2560]", arg17_1: "f32[2560]", arg18_1: "f32[2560, 10240]", arg19_1: "f32[2560]", arg20_1: "f32[2560]", arg21_1: "f32[2560]", arg22_1: "f32[2560, 2560]", arg23_1: "f32[2560]", arg24_1: "f32[2560, 2560]", arg25_1: "f32[2560]", arg26_1: "f32[2560, 2560]", arg27_1: "f32[2560]", arg28_1: "f32[2560, 2560]", arg29_1: "f32[10240]", arg30_1: "f32[10240, 2560]", arg31_1: "f32[2560]", arg32_1: "f32[2560, 10240]", arg33_1: "f64[]", arg34_1: "f64[]", arg35_1: "f64[]", arg36_1: "i64[]", arg37_1: "f64[]", arg38_1: "f32[16]", arg39_1: "i64[s0, s1]", arg40_1: "i64[s2, s3]", arg41_1: "f32[s4, s5, s6, s7]", arg42_1: "f32[s8, s9, s10, s11]", arg43_1: "f32[s12, s13, s14, s15]", arg44_1: "f32[s16, s17, s18, s19]"):
# No stacktrace found for following nodes
embedding: "f32[s0, s1, 2560]" = torch.ops.aten.embedding.default(arg2_1, arg39_1); arg2_1 = None
sym_size: "Sym(s6)" = torch.ops.aten.sym_size.int(arg41_1, 2); sym_size = None
sym_size_int: "Sym(s6)" = torch.ops.aten.sym_size.int(arg41_1, 2); arg41_1 = None
scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.int64)
sym_size_1: "Sym(s1)" = torch.ops.aten.sym_size.int(embedding, 1); embedding = sym_size_1 = None
sym_size_int_1: "Sym(s1)" = torch.ops.aten.sym_size.int(arg39_1, 1); arg39_1 = None
scalar_tensor_1: "i64[]" = torch.ops.aten.scalar_tensor.default(sym_size_int_1, dtype = torch.int64); sym_size_int_1 = None
add: "i64[]" = torch.ops.aten.add.Tensor(scalar_tensor, scalar_tensor_1); scalar_tensor = scalar_tensor_1 = None
item: "Sym(u0)" = torch.ops.aten.item.default(add); add = None
arange = torch.ops.aten.arange.start(sym_size_int, item, device = device(type='cpu'), pin_memory = False); sym_size_int = item = arange = None
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script... ❌
export failed due to Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
- Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
- Debug `torch.export.export` and summit a PR to PyTorch.
- Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.
## Exception summary
<class 'torch._dynamo.exc.UserError'>: Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['args'][1]['input_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['attention_mask'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['key_cache'][1].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['value_cache'][1].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
Suggested fixes:
batch = 2
(Refer to the full stack trace above for more information.)
The export fails for a couple of reasons, but it is possible to patch the
code to make it work. All those modifications are put in place by
bypass_export_some_errors (module onnx_export_errors)
and reverted after the export is done. Among other things, this function registers
serialization functions, as shown in the example
Export a model using a custom type as input.
from experimental_experiment.torch_interpreter.onnx_export_errors import (
bypass_export_some_errors,
)
# Every patch is active only inside the context manager and is reverted
# when it exits.
with bypass_export_some_errors(
    verbose=1, patch_transformers=True, replace_dynamic_cache=True
) as modificator:
    print("inputs before", string_type(inputs, with_shape=True))
    # converts the inputs so that the cache uses the patched class
    inputs = modificator(inputs)
    print("inputs after", string_type(inputs, with_shape=True))
    # Same step without the ONNX conversion:
    # ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False)
    ep = torch.onnx.export(
        model,
        (),
        kwargs=copy.deepcopy(inputs),
        dynamic_shapes=dynamic_shapes,
        dynamo=True,
    )
    ep.optimize()
    ep.save("plot_exporter_recipes_oe_phi2.onnx")
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] modifies shape constraints
[bypass_export_some_errors] patch transformers
[bypass_export_some_errors] replace DynamicCache
inputs before dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
inputs after dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:patched_DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 31 of general pattern rewrite rules.
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored shape constraints
[bypass_export_some_errors] restored transformer
[bypass_export_some_errors] restored DynamicCache
Exported Model
Let’s display the model.
# Reload the file saved by the patched export and print a text view of the graph.
onx = onnx.load("plot_exporter_recipes_oe_phi2.onnx")
print(pretty_onnx(onx))
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib' version=1
input: name='input_ids' type=dtype('int64') shape=['s0', 's1']
input: name='attention_mask' type=dtype('int64') shape=['s0', 's1 + s11']
input: name='past_key_values_key_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_key_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
init: name='model.embed_tokens.weight' type=float32 shape=(51200, 2560)
init: name='model.layers.0.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.0.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.0.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.0.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.0.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.1.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.1.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.1.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.1.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.final_layernorm.weight' type=float32 shape=(2560,)
init: name='model.final_layernorm.bias' type=float32 shape=(2560,)
init: name='lm_head.weight' type=float32 shape=(51200, 2560)
init: name='lm_head.bias' type=float32 shape=(51200,)
Constant(value_int=1) -> diagonal
Shape(input_ids, end=1, start=0) -> val_0
Squeeze(val_0) -> sym_size_int_51
Shape(input_ids, end=2, start=1) -> val_1
Squeeze(val_1) -> sym_size_int_52
Shape(past_key_values_key_cache_0, end=3, start=2) -> val_2
Squeeze(val_2) -> sym_size_int_53
Add(sym_size_int_53, sym_size_int_52) -> add_4
Gather(model.embed_tokens.weight, input_ids, axis=0) -> embedding
LayerNormalization(embedding, model.layers.0.input_layernorm.weight, model.layers.0.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm
Constant(value=1.0) -> val_3
Constant(value=1) -> val_4
Range(sym_size_int_53, add_4, val_4) -> arange
Constant(value=0) -> dim_0
Unsqueeze(arange, dim_0) -> unsqueeze
Constant(value=-3.4028234...) -> val_6
Constant(value=[-1]) -> val_7
Reshape(sym_size_int_52, val_7, allowzero=0) -> val_8
Constant(value=[-1]) -> val_9
Reshape(add_4, val_9, allowzero=0) -> val_10
Concat(val_8, val_10, axis=0) -> val_11
Expand(val_6, val_11) -> full
Trilu(full, diagonal, upper=1) -> triu
Constant(value=0) -> val_14
Constant(value=1) -> val_15
Range(val_14, add_4, val_15) -> arange_1
Constant(value=[-1, 1]) -> val_17
Reshape(arange, val_17, allowzero=0) -> view
Greater(arange_1, view) -> gt
Cast(gt, to=1) -> convert_element_type_default
Mul(triu, convert_element_type_default) -> mul_16
Constant(value=0) -> dim_0_2
Unsqueeze(mul_16, dim_0_2) -> unsqueeze_3
Constant(value=1) -> dim_0_3
Unsqueeze(unsqueeze_3, dim_0_3) -> unsqueeze_4
Constant(value=0) -> val_18
Constant(value=2) -> val_26
Constant(value=[-1]) -> val_42
Reshape(sym_size_int_51, val_42, allowzero=0) -> val_43
Constant(value=[1]) -> val_44
Constant(value=[-1]) -> val_45
Concat(val_43, val_44, val_45, val_45, axis=0) -> val_46
Abs(val_46) -> size_1
Expand(unsqueeze_4, size_1) -> expand_1
Shape(expand_1, start=0) -> val_180
Gather(val_180, val_26, axis=0) -> val_181
Constant(value=1) -> val_63
Range(val_18, val_181, val_63) -> val_182
Constant(value=1) -> dim_0_4
Unsqueeze(attention_mask, dim_0_4) -> unsqueeze_5
Constant(value=2) -> dim_0_5
Unsqueeze(unsqueeze_5, dim_0_5) -> unsqueeze_6
Cast(unsqueeze_6, to=1) -> convert_element_type_default_1
Add(expand_1, convert_element_type_default_1) -> add_84
Constant(value=0.0) -> scalar_tensor_default
Equal(add_84, scalar_tensor_default) -> eq_57
Constant(value=-3.4028234...) -> val_128
Where(eq_57, val_128, expand_1) -> masked_fill
Transpose(masked_fill, perm=[2,1,0,3]) -> val_189
Constant(value=-1) -> val_187
Unsqueeze(val_182, val_187) -> val_188
Transpose(expand_1, perm=[2,1,0,3]) -> val_190
ScatterND(val_190, val_188, val_189, reduction=b'none') -> val_191
Transpose(val_191, perm=[2,1,0,3]) -> slice_scatter
Transpose(slice_scatter, perm=[1,0,2,3]) -> val_201
Shape(expand_1, start=0) -> val_193
Gather(val_193, val_63, axis=0) -> val_194
Range(val_18, val_194, val_63) -> val_195
Unsqueeze(val_195, val_187) -> val_200
Transpose(expand_1, perm=[1,0,2,3]) -> val_202
ScatterND(val_202, val_200, val_201, reduction=b'none') -> val_203
Transpose(val_203, perm=[1,0,2,3]) -> slice_scatter_1
Shape(expand_1, start=0) -> val_205
Gather(val_205, val_18, axis=0) -> val_206
Range(val_18, val_206, val_63) -> val_207
Unsqueeze(val_207, val_187) -> val_212
ScatterND(expand_1, val_212, slice_scatter_1, reduction=b'none') -> slice_scatter_2
Constant(value=1) -> dim_0_8
Unsqueeze(unsqueeze, dim_0_8) -> unsqueeze_9
Cast(unsqueeze_9, to=1) -> _to_copy_1
Constant(value=[[[1.0], [...) -> _to_copy_2
MatMul(_to_copy_2, _to_copy_1) -> matmul
Transpose(matmul, perm=[0,2,1]) -> transpose
Concat(transpose, transpose, axis=-1) -> cat
Cos(cat) -> cos
Sin(cat) -> sin
Transpose(model.layers.0.self_attn.q_proj.weight, perm=[1,0]) -> val_244
MatMul(layer_norm, val_244) -> val_245
Add(val_245, model.layers.0.self_attn.q_proj.bias) -> linear
Constant(value=[-1]) -> val_246
Reshape(sym_size_int_51, val_246, allowzero=0) -> val_247
Constant(value=[-1]) -> val_248
Reshape(sym_size_int_52, val_248, allowzero=0) -> val_249
Constant(value=[80]) -> val_250
Concat(val_247, val_249, val_45, val_250, axis=0) -> val_251
Reshape(linear, val_251, allowzero=0) -> view_1
Transpose(view_1, perm=[0,2,1,3]) -> transpose_1
Transpose(model.layers.0.self_attn.k_proj.weight, perm=[1,0]) -> val_253
MatMul(layer_norm, val_253) -> val_254
Add(val_254, model.layers.0.self_attn.k_proj.bias) -> linear_1
Constant(value=[-1]) -> val_255
Reshape(sym_size_int_51, val_255, allowzero=0) -> val_256
Constant(value=[-1]) -> val_257
Reshape(sym_size_int_52, val_257, allowzero=0) -> val_258
Concat(val_256, val_258, val_45, val_250, axis=0) -> val_259
Reshape(linear_1, val_259, allowzero=0) -> view_2
Transpose(view_2, perm=[0,2,1,3]) -> transpose_2
Transpose(model.layers.0.self_attn.v_proj.weight, perm=[1,0]) -> val_261
MatMul(layer_norm, val_261) -> val_262
Add(val_262, model.layers.0.self_attn.v_proj.bias) -> linear_2
Constant(value=[-1]) -> val_263
Reshape(sym_size_int_51, val_263, allowzero=0) -> val_264
Constant(value=[-1]) -> val_265
Reshape(sym_size_int_52, val_265, allowzero=0) -> val_266
Concat(val_264, val_266, val_45, val_250, axis=0) -> val_267
Reshape(linear_2, val_267, allowzero=0) -> view_3
Transpose(view_3, perm=[0,2,1,3]) -> transpose_3
Concat(past_key_values_value_cache_0, transpose_3, axis=-2) -> cat_6
Constant(value=[0]) -> val_271
Constant(value=[32]) -> val_275
Constant(value=[3]) -> val_278
Constant(value_ints=[1]) -> val_279
Slice(transpose_1, val_271, val_275, val_278, val_279) -> slice_24
Constant(value=[32]) -> val_282
Constant(value=[922337203...) -> val_285
Constant(value=[3]) -> val_288
Constant(value_ints=[1]) -> val_289
Slice(transpose_1, val_282, val_285, val_288, val_289) -> slice_25
Constant(value=[0]) -> val_292
Constant(value=[32]) -> val_295
Constant(value=[3]) -> val_298
Constant(value_ints=[1]) -> val_299
Slice(transpose_2, val_292, val_295, val_298, val_299) -> slice_26
Constant(value=[32]) -> val_302
Constant(value=[922337203...) -> val_305
Constant(value=[3]) -> val_308
Constant(value_ints=[1]) -> val_309
Slice(transpose_2, val_302, val_305, val_308, val_309) -> slice_27
Constant(value=1) -> dim_0_9
Unsqueeze(cos, dim_0_9) -> unsqueeze_10
Mul(slice_24, unsqueeze_10) -> mul_214
Constant(value=1) -> dim_0_10
Unsqueeze(sin, dim_0_10) -> unsqueeze_11
Constant(value=[0]) -> val_312
Constant(value=[16]) -> val_316
Constant(value=[3]) -> val_319
Constant(value_ints=[1]) -> val_320
Slice(slice_24, val_312, val_316, val_319, val_320) -> slice_28
Constant(value=[16]) -> val_323
Constant(value=[922337203...) -> val_326
Constant(value=[3]) -> val_329
Constant(value_ints=[1]) -> val_330
Slice(slice_24, val_323, val_326, val_329, val_330) -> slice_29
Neg(slice_29) -> neg
Concat(neg, slice_28, axis=-1) -> cat_1
Mul(cat_1, unsqueeze_11) -> mul_231
Add(mul_214, mul_231) -> add_288
Concat(add_288, slice_25, axis=-1) -> cat_3
Mul(slice_26, unsqueeze_10) -> mul_239
Constant(value=[0]) -> val_333
Constant(value=[16]) -> val_336
Constant(value=[3]) -> val_339
Constant(value_ints=[1]) -> val_340
Slice(slice_26, val_333, val_336, val_339, val_340) -> slice_30
Constant(value=[16]) -> val_343
Constant(value=[922337203...) -> val_346
Constant(value=[3]) -> val_349
Constant(value_ints=[1]) -> val_350
Slice(slice_26, val_343, val_346, val_349, val_350) -> slice_31
Neg(slice_31) -> neg_1
Concat(neg_1, slice_30, axis=-1) -> cat_2
Mul(cat_2, unsqueeze_11) -> mul_256
Add(mul_239, mul_256) -> add_324
Concat(add_324, slice_27, axis=-1) -> cat_4
Concat(past_key_values_key_cache_0, cat_4, axis=-2) -> cat_5
Shape(cat_5, start=0) -> val_383
Constant(value_ints=[9223372036854775807]) -> val_384
Slice(val_383, val_45, val_384) -> val_385
Constant(value=[-2]) -> val_386
Slice(val_383, val_386, val_45) -> val_387
Constant(value_ints=[-9223372036854775808]) -> val_388
Slice(val_383, val_388, val_386) -> val_389
Concat(val_389, val_385, val_387, axis=0) -> val_394
Constant(value_ints=[-1]) -> val_390
Concat(val_390, val_387, val_385, axis=0) -> val_391
Reshape(cat_5, val_391, allowzero=0) -> val_392
Transpose(val_392, perm=[0,2,1]) -> val_393
Reshape(val_393, val_394, allowzero=0) -> val_395
Constant(value=0.33437013...) -> val_396
Mul(cat_3, val_396) -> val_397
Constant(value=0.33437013...) -> val_398
Mul(val_395, val_398) -> val_399
MatMul(val_397, val_399) -> val_400
Add(val_400, slice_scatter_2) -> val_401
Softmax(val_401, axis=-1) -> val_402
MatMul(val_402, cat_6) -> scaled_dot_product_attention
Transpose(scaled_dot_product_attention, perm=[0,2,1,3]) -> transpose_4
Constant(value=[-1]) -> val_405
Reshape(sym_size_int_51, val_405, allowzero=0) -> val_406
Constant(value=[-1]) -> val_407
Reshape(sym_size_int_52, val_407, allowzero=0) -> val_408
Concat(val_406, val_408, val_45, axis=0) -> val_409
Reshape(transpose_4, val_409, allowzero=0) -> view_4
Transpose(model.layers.0.self_attn.dense.weight, perm=[1,0]) -> val_411
MatMul(view_4, val_411) -> val_412
Add(val_412, model.layers.0.self_attn.dense.bias) -> linear_3
Transpose(model.layers.0.mlp.fc1.weight, perm=[1,0]) -> val_413
MatMul(layer_norm, val_413) -> val_414
Add(val_414, model.layers.0.mlp.fc1.bias) -> linear_4
Constant(value=0.5) -> val_415
Mul(linear_4, val_415) -> mul_323
Constant(value=3.0) -> val_416
Pow(linear_4, val_416) -> pow_1
Constant(value=0.04471499...) -> val_417
Mul(pow_1, val_417) -> mul_330
Add(linear_4, mul_330) -> add_408
Constant(value=0.79788458...) -> val_418
Mul(add_408, val_418) -> mul_337
Tanh(mul_337) -> tanh
Add(tanh, val_3) -> add_421
Mul(mul_323, add_421) -> mul_347
Transpose(model.layers.0.mlp.fc2.weight, perm=[1,0]) -> val_419
MatMul(mul_347, val_419) -> val_420
Add(val_420, model.layers.0.mlp.fc2.bias) -> linear_5
Add(linear_3, linear_5) -> add_438
Add(add_438, embedding) -> add_443
LayerNormalization(add_443, model.layers.1.input_layernorm.weight, model.layers.1.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_1
Transpose(model.layers.1.self_attn.q_proj.weight, perm=[1,0]) -> val_421
MatMul(layer_norm_1, val_421) -> val_422
Add(val_422, model.layers.1.self_attn.q_proj.bias) -> linear_6
Constant(value=[-1]) -> val_423
Reshape(sym_size_int_51, val_423, allowzero=0) -> val_424
Constant(value=[-1]) -> val_425
Reshape(sym_size_int_52, val_425, allowzero=0) -> val_426
Concat(val_424, val_426, val_45, val_250, axis=0) -> val_427
Reshape(linear_6, val_427, allowzero=0) -> view_5
Transpose(view_5, perm=[0,2,1,3]) -> transpose_5
Transpose(model.layers.1.self_attn.k_proj.weight, perm=[1,0]) -> val_429
MatMul(layer_norm_1, val_429) -> val_430
Add(val_430, model.layers.1.self_attn.k_proj.bias) -> linear_7
Constant(value=[-1]) -> val_431
Reshape(sym_size_int_51, val_431, allowzero=0) -> val_432
Constant(value=[-1]) -> val_433
Reshape(sym_size_int_52, val_433, allowzero=0) -> val_434
Concat(val_432, val_434, val_45, val_250, axis=0) -> val_435
Reshape(linear_7, val_435, allowzero=0) -> view_6
Transpose(view_6, perm=[0,2,1,3]) -> transpose_6
Transpose(model.layers.1.self_attn.v_proj.weight, perm=[1,0]) -> val_437
MatMul(layer_norm_1, val_437) -> val_438
Add(val_438, model.layers.1.self_attn.v_proj.bias) -> linear_8
Constant(value=[-1]) -> val_439
Reshape(sym_size_int_51, val_439, allowzero=0) -> val_440
Constant(value=[-1]) -> val_441
Reshape(sym_size_int_52, val_441, allowzero=0) -> val_442
Concat(val_440, val_442, val_45, val_250, axis=0) -> val_443
Reshape(linear_8, val_443, allowzero=0) -> view_7
Transpose(view_7, perm=[0,2,1,3]) -> transpose_7
Concat(past_key_values_value_cache_1, transpose_7, axis=-2) -> cat_12
Constant(value=[0]) -> val_447
Constant(value=[32]) -> val_450
Constant(value=[3]) -> val_453
Constant(value_ints=[1]) -> val_454
Slice(transpose_5, val_447, val_450, val_453, val_454) -> slice_38
Constant(value=[32]) -> val_457
Constant(value=[922337203...) -> val_460
Constant(value=[3]) -> val_463
Constant(value_ints=[1]) -> val_464
Slice(transpose_5, val_457, val_460, val_463, val_464) -> slice_39
Constant(value=[0]) -> val_467
Constant(value=[32]) -> val_470
Constant(value=[3]) -> val_473
Constant(value_ints=[1]) -> val_474
Slice(transpose_6, val_467, val_470, val_473, val_474) -> slice_40
Constant(value=[32]) -> val_477
Constant(value=[922337203...) -> val_480
Constant(value=[3]) -> val_483
Constant(value_ints=[1]) -> val_484
Slice(transpose_6, val_477, val_480, val_483, val_484) -> slice_41
Constant(value=1) -> dim_0_11
Unsqueeze(cos, dim_0_11) -> unsqueeze_12
Mul(slice_38, unsqueeze_12) -> mul_413
Constant(value=1) -> dim_0_12
Unsqueeze(sin, dim_0_12) -> unsqueeze_13
Constant(value=[0]) -> val_487
Constant(value=[16]) -> val_490
Constant(value=[3]) -> val_493
Constant(value_ints=[1]) -> val_494
Slice(slice_38, val_487, val_490, val_493, val_494) -> slice_42
Constant(value=[16]) -> val_497
Constant(value=[922337203...) -> val_500
Constant(value=[3]) -> val_503
Constant(value_ints=[1]) -> val_504
Slice(slice_38, val_497, val_500, val_503, val_504) -> slice_43
Neg(slice_43) -> neg_2
Concat(neg_2, slice_42, axis=-1) -> cat_7
Mul(cat_7, unsqueeze_13) -> mul_430
Add(mul_413, mul_430) -> add_550
Concat(add_550, slice_39, axis=-1) -> cat_9
Mul(slice_40, unsqueeze_12) -> mul_438
Constant(value=[0]) -> val_507
Constant(value=[16]) -> val_510
Constant(value=[3]) -> val_513
Constant(value_ints=[1]) -> val_514
Slice(slice_40, val_507, val_510, val_513, val_514) -> slice_44
Constant(value=[16]) -> val_517
Constant(value=[922337203...) -> val_520
Constant(value=[3]) -> val_523
Constant(value_ints=[1]) -> val_524
Slice(slice_40, val_517, val_520, val_523, val_524) -> slice_45
Neg(slice_45) -> neg_3
Concat(neg_3, slice_44, axis=-1) -> cat_8
Mul(cat_8, unsqueeze_13) -> mul_455
Add(mul_438, mul_455) -> add_586
Concat(add_586, slice_41, axis=-1) -> cat_10
Concat(past_key_values_key_cache_1, cat_10, axis=-2) -> cat_11
Shape(cat_11, start=0) -> val_556
Slice(val_556, val_386, val_45) -> val_559
Constant(value_ints=[9223372036854775807]) -> val_557
Slice(val_556, val_45, val_557) -> val_558
Constant(value_ints=[-9223372036854775808]) -> val_560
Slice(val_556, val_560, val_386) -> val_561
Concat(val_561, val_558, val_559, axis=0) -> val_566
Constant(value_ints=[-1]) -> val_562
Concat(val_562, val_559, val_558, axis=0) -> val_563
Reshape(cat_11, val_563, allowzero=0) -> val_564
Transpose(val_564, perm=[0,2,1]) -> val_565
Reshape(val_565, val_566, allowzero=0) -> val_567
Constant(value=0.33437013...) -> val_568
Mul(cat_9, val_568) -> val_569
Constant(value=0.33437013...) -> val_570
Mul(val_567, val_570) -> val_571
MatMul(val_569, val_571) -> val_572
Add(val_572, slice_scatter_2) -> val_573
Softmax(val_573, axis=-1) -> val_574
MatMul(val_574, cat_12) -> scaled_dot_product_attention_1
Transpose(scaled_dot_product_attention_1, perm=[0,2,1,3]) -> transpose_8
Constant(value=[-1]) -> val_577
Reshape(sym_size_int_51, val_577, allowzero=0) -> val_578
Constant(value=[-1]) -> val_579
Reshape(sym_size_int_52, val_579, allowzero=0) -> val_580
Concat(val_578, val_580, val_45, axis=0) -> val_581
Reshape(transpose_8, val_581, allowzero=0) -> view_8
Transpose(model.layers.1.self_attn.dense.weight, perm=[1,0]) -> val_583
MatMul(view_8, val_583) -> val_584
Add(val_584, model.layers.1.self_attn.dense.bias) -> linear_9
Transpose(model.layers.1.mlp.fc1.weight, perm=[1,0]) -> val_585
MatMul(layer_norm_1, val_585) -> val_586
Add(val_586, model.layers.1.mlp.fc1.bias) -> linear_10
Mul(linear_10, val_415) -> mul_522
Pow(linear_10, val_416) -> pow_2
Mul(pow_2, val_417) -> mul_529
Add(linear_10, mul_529) -> add_670
Mul(add_670, val_418) -> mul_536
Tanh(mul_536) -> tanh_1
Add(tanh_1, val_3) -> add_683
Mul(mul_522, add_683) -> mul_546
Transpose(model.layers.1.mlp.fc2.weight, perm=[1,0]) -> val_587
MatMul(mul_546, val_587) -> val_588
Add(val_588, model.layers.1.mlp.fc2.bias) -> linear_11
Add(linear_9, linear_11) -> add_700
Add(add_700, add_443) -> add_705
LayerNormalization(add_705, model.final_layernorm.weight, model.final_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_2
Transpose(lm_head.weight, perm=[1,0]) -> val_619
MatMul(layer_norm_2, val_619) -> val_620
Add(val_620, lm_head.bias) -> linear_12
output: name='linear_12' type=dtype('float32') shape=['s0', 's1', 51200]
output: name='cat_5' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]
output: name='cat_11' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]
output: name='cat_6' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]
output: name='cat_12' type=dtype('float32') shape=['s0', 32, 's1 + s11', 80]
Visually.

Total running time of the script: (0 minutes 24.413 seconds)
Related examples

Export Phi-3.5-mini-instruct with report_exportability