Note
Go to the end to download the full example code.
torch.onnx.export and Phi-2¶
This example exports the Phi-2 model to ONNX. It uses a dummy model: the Phi-2 architecture with untrained weights. The main difficulty is setting the dynamic shapes properly.
Model¶
import copy
from typing import Any, Dict
import onnx
import torch
import transformers
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import string_type, pretty_onnx
def get_phi2_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets an untrained model with two sets of dummy inputs and
    the dynamic shapes describing them.

    :param batch_size: batch size of the first set of inputs,
        the second set uses ``batch_size + 1``
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary with keys ``model``, ``inputs``, ``inputs2``
        (a second set of inputs with different dimensions) and ``dynamic_shapes``

    See `Phi-2/config.json
    <https://huggingface.co/microsoft/phi-2/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "microsoft/phi-2",
        "architectures": ["PhiForCausalLM"],
        "attention_dropout": 0.0,
        "bos_token_id": 50256,
        "embd_pdrop": 0.0,
        "eos_token_id": 50256,
        "hidden_act": "gelu_new",
        "hidden_size": 2560,
        "initializer_range": 0.02,
        "intermediate_size": 10240,
        "layer_norm_eps": 1e-05,
        "max_position_embeddings": 2048,
        "model_type": "phi",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "partial_rotary_factor": 0.4,
        "qk_layernorm": False,
        "resid_pdrop": 0.1,
        "rope_scaling": None,
        "rope_theta": 10000.0,
        "tie_word_embeddings": False,
        "torch_dtype": "float16",
        "transformers_version": "4.37.0",
        "use_cache": True,
        "vocab_size": 51200,
    }
    config.update(**kwargs)
    conf = transformers.PhiConfig(**config)
    model = transformers.PhiForCausalLM(conf)
    model.eval()

    # The sizes of the dummy inputs are read from the configuration so that
    # values overridden through ``kwargs`` stay consistent with the model.
    # With the default configuration this gives 32 heads of dimension 80.
    n_layers = config["num_hidden_layers"]
    n_kv_heads = conf.num_key_value_heads
    head_dim = conf.hidden_size // conf.num_attention_heads
    # Token ids must stay below the vocabulary size; 50285 is the historical
    # bound, kept so that the default configuration draws the same values.
    max_token_id = min(50285, conf.vocab_size)

    def _make_cache(n_rows: int, past_length: int):
        # Fills a cache with one random (key, value) pair per layer.
        filled = transformers.cache_utils.DynamicCache(n_layers)
        shape = (n_rows, n_kv_heads, past_length, head_dim)
        for layer_idx in range(n_layers):
            key = torch.randn(shape)
            value = torch.randn(shape)
            filled.update(key, value, layer_idx)
        return filled

    # Two sets of inputs with different batch, sequence and cache lengths.
    # The attention mask covers the cached tokens plus the new ones.
    past_length, new_tokens = 30, 3
    past_length2, new_tokens2 = 31, 4

    cache = _make_cache(batch_size, past_length)
    cache2 = _make_cache(batch_size + 1, past_length2)

    inputs = dict(
        input_ids=torch.randint(
            0, max_token_id, (batch_size, new_tokens), dtype=torch.int64
        ),
        attention_mask=torch.ones(
            (batch_size, past_length + new_tokens), dtype=torch.int64
        ),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(
            0, max_token_id, (batch_size + 1, new_tokens2), dtype=torch.int64
        ),
        attention_mask=torch.ones(
            (batch_size + 1, past_length2 + new_tokens2), dtype=torch.int64
        ),
        past_key_values=cache2,
    )

    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = torch.export.Dim("cache_length", min=1, max=4096)

    n = len(cache.key_cache)
    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "attention_mask": {
            0: batch,
            1: torch.export.Dim.DYNAMIC,  # cache_length + seq_length
        },
        # One list for the keys, one for the values, one entry per layer.
        "past_key_values": [
            [{0: batch, 2: cache_length} for _ in range(n)],
            [{0: batch, 2: cache_length} for _ in range(n)],
        ],
    }
    return dict(inputs=inputs, model=model, dynamic_shapes=shapes, inputs2=inputs2)
# Build the dummy model with 2 hidden layers instead of the default 32.
data = get_phi2_untrained(num_hidden_layers=2)
model = data["model"]
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]
# Display the structure of the inputs (with tensor shapes) and the dynamic shapes.
print("inputs", string_type(inputs, with_shape=True))
print("dynamic_shapes", dynamic_shapes)
inputs dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
dynamic_shapes {'input_ids': {0: <class '__main__.batch'>, 1: <class '__main__.seq_length'>}, 'attention_mask': {0: <class '__main__.batch'>, 1: <_DimHint.DYNAMIC: 3>}, 'past_key_values': [[{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}], [{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}]]}
Let’s check that the model runs. The inputs must be copied before calling the model because the model modifies them in place; without a copy, they would no longer be in their original state when the export starts.
# Run the model on a deep copy so that ``inputs`` stays unchanged for the export.
model(**copy.deepcopy(inputs))
CausalLMOutputWithPast(loss=None, logits=tensor([[[ 9.6265e-01, 3.7639e-01, -1.3292e+00, ..., -1.0676e-03,
4.6708e-01, 3.7059e-01],
[ 1.5374e+00, 1.5214e+00, -9.4288e-01, ..., 1.3344e+00,
-4.0183e-01, -1.3003e-02],
[ 6.1904e-02, 1.2486e+00, -1.5581e+00, ..., 1.4394e+00,
1.8182e-01, -9.7775e-01]],
[[-6.1760e-01, 6.2589e-01, -1.4408e+00, ..., -3.2613e-01,
7.2200e-01, -9.4777e-01],
[ 2.0299e-01, 1.7264e+00, -4.9078e-01, ..., -1.0258e+00,
1.3857e+00, 2.7226e-01],
[ 1.1295e-01, -1.7574e-01, -8.3614e-01, ..., -1.0082e+00,
1.8203e+00, 5.9326e-02]]], grad_fn=<ViewBackward0>), past_key_values=DynamicCache(), hidden_states=None, attentions=None)
Export¶
Let’s export with torch.onnx.export().
# First attempt: export deep copies of the model and the inputs as they are.
# Any failure is caught and printed so that the rest of the script keeps running.
try:
torch.onnx.export(
copy.deepcopy(model),
(),
kwargs=copy.deepcopy(inputs),
dynamic_shapes=dynamic_shapes,
dynamo=True,
)
except Exception as e:
print(f"export failed due to {e}")
/home/xadupre/github/onnxscript/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
/home/xadupre/github/onnxscript/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
param_schemas = callee.param_schemas()
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script...
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/modeling_utils.py:5055: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
warnings.warn(
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:460: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/models/phi/modeling_phi.py:1126: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if sequence_length != 1:
/home/xadupre/vv/this312/lib/python3.12/site-packages/transformers/cache_utils.py:444: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
len(self.key_cache[layer_idx]) == 0
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with Torch Script... ❌
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with internal Dynamo apis...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with internal Dynamo apis... ❌
export failed due to Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
- Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
- Debug `torch.export.export` and summit a PR to PyTorch.
- Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.
## Exception summary
<class 'torch._dynamo.exc.UserError'>: Cannot associate shape [[{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}], [{0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}, {0: <class '__main__.batch'>, 2: <class '__main__.cache_length'>}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
(Refer to the full stack trace above for more information.)
The export fails for a couple of reasons, but it is possible to patch the
code to make it work. All those modifications are put in place by
bypass_export_some_errors (from module onnx_export_errors)
and reverted after the export is done. Among other things, this function registers
serialization functions as shown in the example
Export a model using a custom type as input.
from experimental_experiment.torch_interpreter.onnx_export_errors import (
bypass_export_some_errors,
)
# Second attempt: export with the temporary patches enabled.
# ``verbose=1`` makes the context manager print what it patches and restores.
with bypass_export_some_errors(
patch_transformers=True, replace_dynamic_cache=True, verbose=1
) as modificator:
print("inputs before", string_type(inputs, with_shape=True))
# ``modificator`` converts the inputs into what the patched export expects.
inputs = modificator(inputs)
print("inputs after", string_type(inputs, with_shape=True))
# ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False)
ep = torch.onnx.export(
model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
)
# Optimize the exported model, then write it to disk.
ep.optimize()
ep.save("plot_exporter_recipes_oe_phi2.onnx")
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[bypass_export_some_errors] register MambaCache
[bypass_export_some_errors] register DynamicCache
[bypass_export_some_errors] register patched_DynamicCache
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] patch transformers
[bypass_export_some_errors] replace DynamicCache
inputs before dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
inputs after dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:patched_DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PhiForCausalLM([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 31 of general pattern rewrite rules.
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored transformer
[bypass_export_some_errors] restored DynamicCache
[bypass_export_some_errors] unregistered MambaCache
[bypass_export_some_errors] unregistered DynamicCache
[bypass_export_some_errors] unregistered patched_DynamicCache
Exported Model¶
Let’s display the model.
# Reload the saved ONNX file and print a text rendering of the model.
onx = onnx.load("plot_exporter_recipes_oe_phi2.onnx")
print(pretty_onnx(onx))
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib' version=1
input: name='input_ids' type=dtype('int64') shape=['s0', '3']
input: name='attention_mask' type=dtype('int64') shape=['s0', 's11 + 3']
input: name='past_key_values_key_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_key_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_0' type=dtype('float32') shape=['s0', 32, 's11', 80]
input: name='past_key_values_value_cache_1' type=dtype('float32') shape=['s0', 32, 's11', 80]
init: name='model.embed_tokens.weight' type=float32 shape=(51200, 2560)
init: name='model.layers.0.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.0.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.0.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.0.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.0.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.0.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.0.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.0.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.q_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.q_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.k_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.k_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.v_proj.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.v_proj.bias' type=float32 shape=(2560,)
init: name='model.layers.1.self_attn.dense.weight' type=float32 shape=(2560, 2560)
init: name='model.layers.1.self_attn.dense.bias' type=float32 shape=(2560,)
init: name='model.layers.1.mlp.fc1.weight' type=float32 shape=(10240, 2560)
init: name='model.layers.1.mlp.fc1.bias' type=float32 shape=(10240,)
init: name='model.layers.1.mlp.fc2.weight' type=float32 shape=(2560, 10240)
init: name='model.layers.1.mlp.fc2.bias' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.weight' type=float32 shape=(2560,)
init: name='model.layers.1.input_layernorm.bias' type=float32 shape=(2560,)
init: name='model.final_layernorm.weight' type=float32 shape=(2560,)
init: name='model.final_layernorm.bias' type=float32 shape=(2560,)
init: name='lm_head.weight' type=float32 shape=(51200, 2560)
init: name='lm_head.bias' type=float32 shape=(51200,)
Constant(value_int=1) -> diagonal
Shape(input_ids, end=1, start=0) -> sym_size_int_64
Shape(input_ids, end=2, start=1) -> sym_size_int_65
Mul(sym_size_int_64, sym_size_int_65) -> mul_171
Shape(past_key_values_key_cache_0, end=3, start=2) -> sym_size_int_66
Add(sym_size_int_66, sym_size_int_65) -> add_4
Concat(sym_size_int_65, add_4, axis=0) -> val_4
Gather(model.embed_tokens.weight, input_ids, axis=0) -> embedding
LayerNormalization(embedding, model.layers.0.input_layernorm.weight, model.layers.0.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm
Constant(value=1.0) -> val_0
Constant(value=1) -> val_1
Range(sym_size_int_66, add_4, val_1) -> arange
Constant(value=0) -> dim_0
Unsqueeze(arange, dim_0) -> unsqueeze
Constant(value=-3.4028234...) -> val_3
Expand(val_3, val_4) -> full
Trilu(full, diagonal, upper=1) -> triu
Constant(value=0) -> val_7
Constant(value=1) -> val_8
Range(val_7, add_4, val_8) -> arange_1
Constant(value=[-1, 1]) -> val_10
Reshape(arange, val_10, allowzero=0) -> view
Greater(arange_1, view) -> gt
Cast(gt, to=1) -> convert_element_type_default
Mul(triu, convert_element_type_default) -> mul_16
Constant(value=0) -> dim_0_2
Unsqueeze(mul_16, dim_0_2) -> unsqueeze_3
Constant(value=1) -> dim_0_3
Unsqueeze(unsqueeze_3, dim_0_3) -> unsqueeze_4
Constant(value=0) -> val_11
Constant(value=2) -> val_19
Constant(value=[1]) -> val_35
Constant(value=[-1]) -> val_36
Concat(sym_size_int_64, val_35, val_36, val_36, axis=0) -> val_37
Abs(val_37) -> size_1
Expand(unsqueeze_4, size_1) -> expand_1
Shape(expand_1, start=0) -> val_171
Gather(val_171, val_19, axis=0) -> val_172
Constant(value=1) -> val_54
Range(val_11, val_172, val_54) -> val_173
Constant(value=1) -> dim_0_4
Unsqueeze(attention_mask, dim_0_4) -> unsqueeze_5
Constant(value=2) -> dim_0_5
Unsqueeze(unsqueeze_5, dim_0_5) -> unsqueeze_6
Cast(unsqueeze_6, to=1) -> convert_element_type_default_1
Add(expand_1, convert_element_type_default_1) -> add_84
Constant(value=0.0) -> scalar_tensor_default
Equal(add_84, scalar_tensor_default) -> eq_57
Constant(value=-3.4028234...) -> val_119
Where(eq_57, val_119, expand_1) -> masked_fill
Transpose(masked_fill, perm=[2,1,0,3]) -> val_180
Constant(value=-1) -> val_178
Unsqueeze(val_173, val_178) -> val_179
Transpose(expand_1, perm=[2,1,0,3]) -> val_181
ScatterND(val_181, val_179, val_180, reduction=b'none') -> val_182
Transpose(val_182, perm=[2,1,0,3]) -> slice_scatter
Transpose(slice_scatter, perm=[1,0,2,3]) -> val_192
Shape(expand_1, start=0) -> val_184
Gather(val_184, val_54, axis=0) -> val_185
Range(val_11, val_185, val_54) -> val_186
Unsqueeze(val_186, val_178) -> val_191
Transpose(expand_1, perm=[1,0,2,3]) -> val_193
ScatterND(val_193, val_191, val_192, reduction=b'none') -> val_194
Transpose(val_194, perm=[1,0,2,3]) -> slice_scatter_1
Shape(expand_1, start=0) -> val_196
Gather(val_196, val_11, axis=0) -> val_197
Range(val_11, val_197, val_54) -> val_198
Unsqueeze(val_198, val_178) -> val_203
ScatterND(expand_1, val_203, slice_scatter_1, reduction=b'none') -> slice_scatter_2
Constant(value=1) -> dim_0_8
Unsqueeze(unsqueeze, dim_0_8) -> unsqueeze_9
Cast(unsqueeze_9, to=1) -> _to_copy_1
Constant(value=[[[1.0], [...) -> _to_copy_2
MatMul(_to_copy_2, _to_copy_1) -> matmul
Transpose(matmul, perm=[0,2,1]) -> transpose
Concat(transpose, transpose, axis=-1) -> cat
Cos(cat) -> cos
Sin(cat) -> sin
Constant(value=[2560]) -> val_235
Concat(mul_171, val_235, axis=0) -> val_236
Reshape(layer_norm, val_236, allowzero=0) -> view_1
Transpose(model.layers.0.self_attn.q_proj.weight, perm=[1,0]) -> t
Gemm(view_1, t, model.layers.0.self_attn.q_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_238
Reshape(addmm, val_238, allowzero=0) -> view_2
Concat(mul_171, val_235, axis=0) -> val_240
Reshape(layer_norm, val_240, allowzero=0) -> view_3
Transpose(model.layers.0.self_attn.k_proj.weight, perm=[1,0]) -> t_1
Gemm(view_3, t_1, model.layers.0.self_attn.k_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_1
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_242
Reshape(addmm_1, val_242, allowzero=0) -> view_4
Concat(mul_171, val_235, axis=0) -> val_244
Reshape(layer_norm, val_244, allowzero=0) -> view_5
Transpose(model.layers.0.self_attn.v_proj.weight, perm=[1,0]) -> t_2
Gemm(view_5, t_2, model.layers.0.self_attn.v_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_2
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_246
Reshape(addmm_2, val_246, allowzero=0) -> view_6
Constant(value=[32]) -> val_248
Constant(value=[80]) -> val_249
Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_250
Reshape(view_2, val_250, allowzero=0) -> view_7
Transpose(view_7, perm=[0,2,1,3]) -> transpose_1
Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_252
Reshape(view_4, val_252, allowzero=0) -> view_8
Transpose(view_8, perm=[0,2,1,3]) -> transpose_2
Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_254
Reshape(view_6, val_254, allowzero=0) -> view_9
Transpose(view_9, perm=[0,2,1,3]) -> transpose_3
Concat(past_key_values_value_cache_0, transpose_3, axis=-2) -> cat_6
Constant(value=[0]) -> val_258
Constant(value=[32]) -> val_262
Constant(value=[3]) -> val_265
Constant(value_ints=[1]) -> val_266
Slice(transpose_1, val_258, val_262, val_265, val_266) -> slice_24
Constant(value=[32]) -> val_269
Constant(value=[922337203...) -> val_272
Constant(value=[3]) -> val_275
Constant(value_ints=[1]) -> val_276
Slice(transpose_1, val_269, val_272, val_275, val_276) -> slice_25
Constant(value=[0]) -> val_279
Constant(value=[32]) -> val_282
Constant(value=[3]) -> val_285
Constant(value_ints=[1]) -> val_286
Slice(transpose_2, val_279, val_282, val_285, val_286) -> slice_26
Constant(value=[32]) -> val_289
Constant(value=[922337203...) -> val_292
Constant(value=[3]) -> val_295
Constant(value_ints=[1]) -> val_296
Slice(transpose_2, val_289, val_292, val_295, val_296) -> slice_27
Constant(value=1) -> dim_0_9
Unsqueeze(cos, dim_0_9) -> unsqueeze_10
Mul(slice_24, unsqueeze_10) -> mul_233
Constant(value=1) -> dim_0_10
Unsqueeze(sin, dim_0_10) -> unsqueeze_11
Constant(value=[0]) -> val_299
Constant(value=[16]) -> val_303
Constant(value=[3]) -> val_306
Constant(value_ints=[1]) -> val_307
Slice(slice_24, val_299, val_303, val_306, val_307) -> slice_28
Constant(value=[16]) -> val_310
Constant(value=[922337203...) -> val_313
Constant(value=[3]) -> val_316
Constant(value_ints=[1]) -> val_317
Slice(slice_24, val_310, val_313, val_316, val_317) -> slice_29
Neg(slice_29) -> neg
Concat(neg, slice_28, axis=-1) -> cat_1
Mul(cat_1, unsqueeze_11) -> mul_250
Add(mul_233, mul_250) -> add_306
Concat(add_306, slice_25, axis=-1) -> cat_3
Shape(cat_3, start=0) -> val_368
Mul(slice_26, unsqueeze_10) -> mul_258
Constant(value=[0]) -> val_320
Constant(value=[16]) -> val_323
Constant(value=[3]) -> val_326
Constant(value_ints=[1]) -> val_327
Slice(slice_26, val_320, val_323, val_326, val_327) -> slice_30
Constant(value=[16]) -> val_330
Constant(value=[922337203...) -> val_333
Constant(value=[3]) -> val_336
Constant(value_ints=[1]) -> val_337
Slice(slice_26, val_330, val_333, val_336, val_337) -> slice_31
Neg(slice_31) -> neg_1
Concat(neg_1, slice_30, axis=-1) -> cat_2
Mul(cat_2, unsqueeze_11) -> mul_275
Add(mul_258, mul_275) -> add_342
Concat(add_342, slice_27, axis=-1) -> cat_4
Concat(past_key_values_key_cache_0, cat_4, axis=-2) -> cat_5
Shape(cat_5, start=0) -> val_377
Constant(value_ints=[-1]) -> val_369
Gather(val_368, val_369, axis=0) -> val_370
Cast(val_370, to=1) -> val_371
Sqrt(val_371) -> val_374
Constant(value=1.0) -> val_373
Div(val_373, val_374) -> val_375
Sqrt(val_375) -> val_390
Mul(cat_3, val_390) -> val_391
Constant(value_ints=[9223372036854775807]) -> val_378
Slice(val_377, val_36, val_378) -> val_379
Constant(value=[-2]) -> val_380
Slice(val_377, val_380, val_36) -> val_381
Constant(value_ints=[-9223372036854775808]) -> val_382
Slice(val_377, val_382, val_380) -> val_383
Concat(val_383, val_379, val_381, axis=0) -> val_388
Constant(value_ints=[-1]) -> val_384
Concat(val_384, val_381, val_379, axis=0) -> val_385
Reshape(cat_5, val_385, allowzero=0) -> val_386
Transpose(val_386, perm=[0,2,1]) -> val_387
Reshape(val_387, val_388, allowzero=0) -> val_389
Sqrt(val_375) -> val_392
Mul(val_389, val_392) -> val_393
MatMul(val_391, val_393) -> val_394
Add(val_394, slice_scatter_2) -> val_395
Softmax(val_395, axis=-1) -> val_396
MatMul(val_396, cat_6) -> scaled_dot_product_attention
Transpose(scaled_dot_product_attention, perm=[0,2,1,3]) -> transpose_4
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_399
Reshape(transpose_4, val_399, allowzero=0) -> view_10
Concat(mul_171, val_235, axis=0) -> val_401
Reshape(view_10, val_401, allowzero=0) -> view_11
Transpose(model.layers.0.self_attn.dense.weight, perm=[1,0]) -> t_3
Gemm(view_11, t_3, model.layers.0.self_attn.dense.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_3
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_403
Reshape(addmm_3, val_403, allowzero=0) -> view_12
Concat(mul_171, val_235, axis=0) -> val_405
Reshape(layer_norm, val_405, allowzero=0) -> view_13
Transpose(model.layers.0.mlp.fc1.weight, perm=[1,0]) -> t_4
Gemm(view_13, t_4, model.layers.0.mlp.fc1.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_4
Constant(value=[10240]) -> val_407
Concat(sym_size_int_64, sym_size_int_65, val_407, axis=0) -> val_408
Reshape(addmm_4, val_408, allowzero=0) -> view_14
Constant(value=0.5) -> val_410
Mul(view_14, val_410) -> mul_356
Constant(value=3.0) -> val_411
Pow(view_14, val_411) -> pow_1
Constant(value=0.04471499...) -> val_412
Mul(pow_1, val_412) -> mul_363
Add(view_14, mul_363) -> add_438
Constant(value=0.79788458...) -> val_413
Mul(add_438, val_413) -> mul_370
Tanh(mul_370) -> tanh
Add(tanh, val_0) -> add_451
Mul(mul_356, add_451) -> mul_380
Concat(mul_171, val_407, axis=0) -> val_414
Reshape(mul_380, val_414, allowzero=0) -> view_15
Transpose(model.layers.0.mlp.fc2.weight, perm=[1,0]) -> t_5
Gemm(view_15, t_5, model.layers.0.mlp.fc2.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_5
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_416
Reshape(addmm_5, val_416, allowzero=0) -> view_16
Add(view_12, view_16) -> add_474
Add(add_474, embedding) -> add_479
LayerNormalization(add_479, model.layers.1.input_layernorm.weight, model.layers.1.input_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_1
Concat(mul_171, val_235, axis=0) -> val_418
Reshape(layer_norm_1, val_418, allowzero=0) -> view_17
Transpose(model.layers.1.self_attn.q_proj.weight, perm=[1,0]) -> t_6
Gemm(view_17, t_6, model.layers.1.self_attn.q_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_6
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_420
Reshape(addmm_6, val_420, allowzero=0) -> view_18
Concat(mul_171, val_235, axis=0) -> val_422
Reshape(layer_norm_1, val_422, allowzero=0) -> view_19
Transpose(model.layers.1.self_attn.k_proj.weight, perm=[1,0]) -> t_7
Gemm(view_19, t_7, model.layers.1.self_attn.k_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_7
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_424
Reshape(addmm_7, val_424, allowzero=0) -> view_20
Concat(mul_171, val_235, axis=0) -> val_426
Reshape(layer_norm_1, val_426, allowzero=0) -> view_21
Transpose(model.layers.1.self_attn.v_proj.weight, perm=[1,0]) -> t_8
Gemm(view_21, t_8, model.layers.1.self_attn.v_proj.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_8
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_428
Reshape(addmm_8, val_428, allowzero=0) -> view_22
Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_430
Reshape(view_18, val_430, allowzero=0) -> view_23
Transpose(view_23, perm=[0,2,1,3]) -> transpose_5
Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_432
Reshape(view_20, val_432, allowzero=0) -> view_24
Transpose(view_24, perm=[0,2,1,3]) -> transpose_6
Concat(sym_size_int_64, sym_size_int_65, val_248, val_249, axis=0) -> val_434
Reshape(view_22, val_434, allowzero=0) -> view_25
Transpose(view_25, perm=[0,2,1,3]) -> transpose_7
Concat(past_key_values_value_cache_1, transpose_7, axis=-2) -> cat_12
Constant(value=[0]) -> val_438
Constant(value=[32]) -> val_441
Constant(value=[3]) -> val_444
Constant(value_ints=[1]) -> val_445
Slice(transpose_5, val_438, val_441, val_444, val_445) -> slice_38
Constant(value=[32]) -> val_448
Constant(value=[922337203...) -> val_451
Constant(value=[3]) -> val_454
Constant(value_ints=[1]) -> val_455
Slice(transpose_5, val_448, val_451, val_454, val_455) -> slice_39
Constant(value=[0]) -> val_458
Constant(value=[32]) -> val_461
Constant(value=[3]) -> val_464
Constant(value_ints=[1]) -> val_465
Slice(transpose_6, val_458, val_461, val_464, val_465) -> slice_40
Constant(value=[32]) -> val_468
Constant(value=[922337203...) -> val_471
Constant(value=[3]) -> val_474
Constant(value_ints=[1]) -> val_475
Slice(transpose_6, val_468, val_471, val_474, val_475) -> slice_41
Constant(value=1) -> dim_0_11
Unsqueeze(cos, dim_0_11) -> unsqueeze_12
Mul(slice_38, unsqueeze_12) -> mul_474
Constant(value=1) -> dim_0_12
Unsqueeze(sin, dim_0_12) -> unsqueeze_13
Constant(value=[0]) -> val_478
Constant(value=[16]) -> val_481
Constant(value=[3]) -> val_484
Constant(value_ints=[1]) -> val_485
Slice(slice_38, val_478, val_481, val_484, val_485) -> slice_42
Constant(value=[16]) -> val_488
Constant(value=[922337203...) -> val_491
Constant(value=[3]) -> val_494
Constant(value_ints=[1]) -> val_495
Slice(slice_38, val_488, val_491, val_494, val_495) -> slice_43
Neg(slice_43) -> neg_2
Concat(neg_2, slice_42, axis=-1) -> cat_7
Mul(cat_7, unsqueeze_13) -> mul_491
Add(mul_474, mul_491) -> add_604
Concat(add_604, slice_39, axis=-1) -> cat_9
Shape(cat_9, start=0) -> val_546
Mul(slice_40, unsqueeze_12) -> mul_499
Constant(value=[0]) -> val_498
Constant(value=[16]) -> val_501
Constant(value=[3]) -> val_504
Constant(value_ints=[1]) -> val_505
Slice(slice_40, val_498, val_501, val_504, val_505) -> slice_44
Constant(value=[16]) -> val_508
Constant(value=[922337203...) -> val_511
Constant(value=[3]) -> val_514
Constant(value_ints=[1]) -> val_515
Slice(slice_40, val_508, val_511, val_514, val_515) -> slice_45
Neg(slice_45) -> neg_3
Concat(neg_3, slice_44, axis=-1) -> cat_8
Mul(cat_8, unsqueeze_13) -> mul_516
Add(mul_499, mul_516) -> add_640
Concat(add_640, slice_41, axis=-1) -> cat_10
Concat(past_key_values_key_cache_1, cat_10, axis=-2) -> cat_11
Shape(cat_11, start=0) -> val_555
Slice(val_555, val_380, val_36) -> val_558
Constant(value_ints=[-1]) -> val_547
Gather(val_546, val_547, axis=0) -> val_548
Cast(val_548, to=1) -> val_549
Sqrt(val_549) -> val_552
Constant(value=1.0) -> val_551
Div(val_551, val_552) -> val_553
Sqrt(val_553) -> val_567
Mul(cat_9, val_567) -> val_568
Constant(value_ints=[9223372036854775807]) -> val_556
Slice(val_555, val_36, val_556) -> val_557
Constant(value_ints=[-9223372036854775808]) -> val_559
Slice(val_555, val_559, val_380) -> val_560
Concat(val_560, val_557, val_558, axis=0) -> val_565
Constant(value_ints=[-1]) -> val_561
Concat(val_561, val_558, val_557, axis=0) -> val_562
Reshape(cat_11, val_562, allowzero=0) -> val_563
Transpose(val_563, perm=[0,2,1]) -> val_564
Reshape(val_564, val_565, allowzero=0) -> val_566
Sqrt(val_553) -> val_569
Mul(val_566, val_569) -> val_570
MatMul(val_568, val_570) -> val_571
Add(val_571, slice_scatter_2) -> val_572
Softmax(val_572, axis=-1) -> val_573
MatMul(val_573, cat_12) -> scaled_dot_product_attention_1
Transpose(scaled_dot_product_attention_1, perm=[0,2,1,3]) -> transpose_8
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_576
Reshape(transpose_8, val_576, allowzero=0) -> view_26
Concat(mul_171, val_235, axis=0) -> val_578
Reshape(view_26, val_578, allowzero=0) -> view_27
Transpose(model.layers.1.self_attn.dense.weight, perm=[1,0]) -> t_9
Gemm(view_27, t_9, model.layers.1.self_attn.dense.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_9
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_580
Reshape(addmm_9, val_580, allowzero=0) -> view_28
Concat(mul_171, val_235, axis=0) -> val_582
Reshape(layer_norm_1, val_582, allowzero=0) -> view_29
Transpose(model.layers.1.mlp.fc1.weight, perm=[1,0]) -> t_10
Gemm(view_29, t_10, model.layers.1.mlp.fc1.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_10
Concat(sym_size_int_64, sym_size_int_65, val_407, axis=0) -> val_584
Reshape(addmm_10, val_584, allowzero=0) -> view_30
Mul(view_30, val_410) -> mul_597
Pow(view_30, val_411) -> pow_2
Mul(pow_2, val_412) -> mul_604
Add(view_30, mul_604) -> add_736
Mul(add_736, val_413) -> mul_611
Tanh(mul_611) -> tanh_1
Add(tanh_1, val_0) -> add_749
Mul(mul_597, add_749) -> mul_621
Concat(mul_171, val_407, axis=0) -> val_586
Reshape(mul_621, val_586, allowzero=0) -> view_31
Transpose(model.layers.1.mlp.fc2.weight, perm=[1,0]) -> t_11
Gemm(view_31, t_11, model.layers.1.mlp.fc2.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_11
Concat(sym_size_int_64, sym_size_int_65, val_235, axis=0) -> val_588
Reshape(addmm_11, val_588, allowzero=0) -> view_32
Add(view_28, view_32) -> add_772
Add(add_772, add_479) -> add_777
LayerNormalization(add_777, model.final_layernorm.weight, model.final_layernorm.bias, epsilon=0.00, axis=-1) -> layer_norm_2
Concat(mul_171, val_235, axis=0) -> val_620
Reshape(layer_norm_2, val_620, allowzero=0) -> view_33
Transpose(lm_head.weight, perm=[1,0]) -> t_12
Gemm(view_33, t_12, lm_head.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_12
Constant(value=[51200]) -> val_622
Concat(sym_size_int_64, sym_size_int_65, val_622, axis=0) -> val_623
Reshape(addmm_12, val_623, allowzero=0) -> view_34
output: name='view_34' type=dtype('float32') shape=['', '', '']
output: name='cat_5' type=dtype('float32') shape=['s0', 32, '', 80]
output: name='cat_6' type=dtype('float32') shape=['s0', 32, '', 80]
output: name='cat_11' type=dtype('float32') shape=['s0', 32, '', 80]
output: name='cat_12' type=dtype('float32') shape=['s0', 32, '', 80]
Visually.
Total running time of the script: (0 minutes 16.675 seconds)
Related examples
to_onnx and submodules from LLMs
torch.onnx.export and a model with a test